Xgboost: [jvm-packages] java.lang.NullPointerException: null at ml.dmlc.xgboost4j.java.Booster.predict

Created on Jul 30, 2020  ·  37 comments  ·  Source: dmlc/xgboost

An NPE is thrown when predicting through the Java API.

java.lang.NullPointerException: null
    at ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:309)
    at ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:375)
    at com.tuhu.predict.predict.BaseModelPredict.predict(BaseModelPredict.java:71)
    at com.tuhu.predict.predict.XgboostFindPageModelPredict.predict(XgboostFindPageModelPredict.java:53)
    at com.tuhu.predict.service.impl.MlpFindPageFeatureServiceImpl.featureProcess(MlpFindPageFeatureServiceImpl.java:65)
    at com.tuhu.predict.api.controller.MlpFindPageController.recommendPredict(MlpFindPageController.java:49)
    at com.tuhu.predict.api.controller.MlpFindPageController$$FastClassBySpringCGLIB$$f694b9ff.invoke()
    at org.springframework.cglib.proxy.MethodProxy.invoke(MethodProxy.java:204)
    at org.springframework.aop.framework.CglibAopProxy$CglibMethodInvocation.invokeJoinpoint(CglibAopProxy.java:746)
    at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)
    at org.springframework.aop.framework.adapter.MethodBeforeAdviceInterceptor.invoke(MethodBeforeAdviceInterceptor.java:52)
    at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
    at org.springframework.aop.aspectj.AspectJAfterAdvice.invoke(AspectJAfterAdvice.java:47)
    at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
    at org.springframework.aop.framework.adapter.AfterReturningAdviceInterceptor.invoke(AfterReturningAdviceInterceptor.java:52)
    at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
    at org.springframework.aop.aspectj.AspectJAfterThrowingAdvice.invoke(AspectJAfterThrowingAdvice.java:62)
    at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
    at org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint.proceed(MethodInvocationProceedingJoinPoint.java:88)
    at com.tuhu.springcloud.common.annotation.AbstractControllerLogAspect.doAround(AbstractControllerLogAspect.java:104)
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethodWithGivenArgs(AbstractAspectJAdvice.java:644)
    at org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethod(AbstractAspectJAdvice.java:633)
    at org.springframework.aop.aspectj.AspectJAroundAdvice.invoke(AspectJAroundAdvice.java:70)
    at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
    at org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:92)
    at org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:185)
    at org.springframework.aop.framework.CglibAopProxy$DynamicAdvisedInterceptor.intercept(CglibAopProxy.java:688)
    at com.tuhu.predict.api.controller.MlpFindPageController$$EnhancerBySpringCGLIB$$560ed775.recommendPredict()
    at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
    at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
    at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
    at java.lang.reflect.Method.invoke(Method.java:498)
    at org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:209)
    at org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:136)
    at org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:102)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:877)
    at org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:783)
    at org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
    at org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:991)
    at org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:925)
    at org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:974)
    at org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:877)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:661)
    at org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:851)
    at javax.servlet.http.HttpServlet.service(HttpServlet.java:742)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:52)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at com.tuhu.soter.starter.filter.SoterDefaultFilter.doFilter(SoterDefaultFilter.java:79)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at com.tuhu.boot.logback.filter.LogFilter.doFilter(LogFilter.java:54)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:158)
    at org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:126)
    at org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.doFilterInternal(WebMvcMetricsFilter.java:111)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter.doFilterInternal(HttpTraceFilter.java:90)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at com.tuhu.boot.common.filter.HeartbeatFilter.doFilter(HeartbeatFilter.java:42)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at com.tuhu.boot.common.filter.MDCFilter.doFilter(MDCFilter.java:47)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.HttpPutFormContentFilter.doFilterInternal(HttpPutFormContentFilter.java:109)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.HiddenHttpMethodFilter.doFilterInternal(HiddenHttpMethodFilter.java:93)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:200)
    at org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
    at org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
    at org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
    at org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:198)
    at org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:96)
    at org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:496)
    at org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:140)
    at org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:81)
    at org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:87)
    at org.apache.catalina.valves.RemoteIpValve.invoke(RemoteIpValve.java:677)
    at org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:342)
    at org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:803)
    at org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:66)
    at org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:790)
    at org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1468)
    at org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
    at java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
    at java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
    at org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
    at java.lang.Thread.run(Thread.java:748)

All 37 comments

The model was trained through Python Sklearn, and incompatibilities appeared later. To save time, the algorithm team wrapped the XGB model trained under Sklearn in a layer on top of the Python XGBoost package. I wonder whether that is the cause.

image

Which version of XGBoost are you using? We previously fixed a bug where the jvm package did not throw an exception correctly when prediction failed and instead continued with an empty prediction buffer.

Our company's algorithm platform uses version 1.0, while the algorithm project uses version 0.9.0 because of version compatibility issues. My algorithm colleagues used Python to convert the 1.0 model file to 0.9.0. I wonder whether this conversion is the cause.

I would suggest waiting for 1.2 (https://github.com/dmlc/xgboost/issues/5734) and trying again; that release has some important bug fixes. I would also suggest using the same or a later xgboost version for prediction. XGBoost's binary model is backward compatible, but going forward the JSON-based model is recommended.

I got the same problem with 1.2.0, so the problem is still here.

I also got the same problem.
I used xgboost4j to create the model.

Is there a workaround?

This is a huge problem for me: it failed jobs in production.

@ranInc Are you using the latest version of XGBoost? So far we do not know the exact cause of this issue. We will address it on a best-effort basis, and since we cannot guarantee when it will be fixed, I suggest that you investigate an alternative in the meantime.

@ranInc You can help us by providing a small example program that we (the developers) can run on our own machines.

I am running 1.2.0, the latest jar in the Maven repository.
The alternative for me is to go back to Spark 2.4.5 and use xgboost 0.9 there, which is what I am doing now.

As for the example: I will try to pinpoint the specific model/data that causes the job to fail.

Hi,
I found the specific model/data.
I am attaching it to this comment.
xgb1.2_bug.zip

This is how to reproduce the bug (keep in mind that it works if you do not repartition here, so it has something to do with the amount or type of data in each partition):

from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession

# Reuse the active session, or create one when run outside the pyspark shell
spark = SparkSession.builder.getOrCreate()

df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()

model = PipelineModel.read().load("/tmp/6620294785024229130_model_xg_only")

predictions = model.transform(df)

predictions.persist()
predictions.count()
predictions.show()

Do you have any idea when this can be addressed?
This prevents me from using Spark 3.0...

@ranInc Not yet. I will let you know when we get around to fixing the bug. Also, can you post the code in Scala? I don't think we ever officially supported the use of PySpark with XGBoost.

      import org.apache.spark.ml.{Pipeline, PipelineModel}
      val df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
      df.count()

      val model = PipelineModel.read.load("/tmp/6620294785024229130_model_xg_only")

      val predictions = model.transform(df)

      predictions.persist()
      predictions.count()
      predictions.show()

Another pointer:
the issue seems to occur when all of the features sent for prediction are zero/missing.

I guess no one is working on this?
It basically means that xgboost does not work with Spark 3 at all.

Yes, sorry, our hands are quite full right now. We will get around to this issue at some point. I respectfully ask for your patience. Thanks.

@ranInc I had some time today, so I tried running the script you provided here. I managed to reproduce the java.lang.NullPointerException error.

Strangely, the latest development version (master branch) does not crash in the same way. Instead, I get this error:

Exception in thread "main" org.apache.spark.SparkException: Job aborted due to stage failure: Task 0 in stage 7.0 failed 1 times, most recent failure: Lost task 0.0 in stage 7.0 (TID 11, d04389c5babb, executor driver): ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.

I will investigate further.

I think the error message makes sense now: your input has more features than the model used for prediction.

Previously, the jvm package would continue after an XGBoost failure, which resulted in an empty prediction buffer. I added a check guard recently.

Just make sure the number of columns in your training dataset is greater than or equal to the number in your prediction dataset.

Hi,
The model was created with the same number of features.
In Spark, a single vector column is used rather than multiple columns.
In any case, the vector size is always the same for fit and predict. I am 100% sure of that.

This has something to do with rows in which all features are zero/missing.
You can see that if you filter out of the DataFrame the rows whose features are all zero, it works just fine.

@ranInc Can you post the full Scala program that generated the model? The error message seems to suggest that your model was trained with a single feature.

I don't think it will help much, as the code is very generic and contains some proprietary transformers.
The code itself is mostly PySpark, not Scala.

The best way to see that the number of features is not the problem is to filter out the rows in which all features are zero and then use the model. This works without a problem.
You can also keep all the rows and repartition the DataFrame into a single partition, and that works as well.

@ranInc I filtered out rows with zeros and still faced the same error (java.lang.NullPointerException):

...
df.na.drop(minNonNulls = 1)
...

Is this not the right way to do it?

I don't think it will help much, as the code is very generic and contains some proprietary transformers.

I want to see how many features are being used at training and at prediction time. The error message

ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.

suggests that the model was trained with a single feature and that prediction is being made with two features.

Right now, I only have access to the DataFrame and the serialized model that you uploaded. I have no insight into what went into the model training or what went wrong, which hinders me from troubleshooting the issue further. If your program contains proprietary information, is it possible to produce a clean example?
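For readers unfamiliar with the check quoted in the error message, the following is a minimal Python sketch of what the condition `learner_model_param_.num_feature >= p_fmat->Info().num_col_` expresses. The function name `validate_num_columns` is hypothetical and this is only a restatement of the quoted condition, not XGBoost's actual implementation.

```python
def validate_num_columns(model_num_feature: int, data_num_col: int) -> None:
    """Toy restatement of the condition quoted in the error message.

    Hypothetical helper: it mirrors the inequality
    num_feature >= num_col and nothing else.
    """
    if not model_num_feature >= data_num_col:
        raise ValueError(
            f"Check failed: num_feature >= num_col "
            f"({model_num_feature} vs. {data_num_col}) : "
            "Number of columns does not match number of features in booster."
        )


# Same width at training and prediction time: the check passes.
validate_num_columns(2, 2)

# Fewer columns than the model knows about: the inequality still holds.
validate_num_columns(2, 1)

# The situation reported in this thread: the model reports 1 feature,
# while the prediction data has 2 columns.
try:
    validate_num_columns(1, 2)
except ValueError as err:
    print(err)
```

Read this way, the "(1 vs. 2)" in the message is the model's feature count followed by the column count of the prediction data.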

  1. No, you can do it like this:
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, col}
.......
val isVectorAllZeros = (col: Vector) => {
  col match {
    // A sparse vector with no stored indices has no non-zero entries
    case sparse: SparseVector => sparse.indices.isEmpty
    case _ => false
  }
}

spark.udf.register("isVectorAllZeros", isVectorAllZeros)
df = df.withColumn("isEmpty", callUDF("isVectorAllZeros",
  col("features_6620294785024229130"))).where("isEmpty == false")

You can also just repartition the DataFrame like this:

....
df = df.repartition(1)
....
  2. I understand, but this code will not provide much information, since it uses VectorAssembler and you would not be able to tell how many features were actually used.
    But I am 100% sure it used the same number of features.

How did you ensure this, if VectorAssembler can produce a variable number of features?

VectorAssembler always creates the same number of features; it just needs the names of the columns to take from.
The code itself is used to create thousands of models, so it is very generic and basically receives a list of names to use.

I might be able to run the model creation again and send you the DataFrame used for the model, or any other data you need.
That will take time, though, and if you use what I showed earlier, you will see that the model works just fine with 2 features.

@ranInc Let me ask one more question: is it correct to say that the example data has a sparse column (VectorAssembler) with at most two features?

No.
VectorAssembler is a Transformer that takes multiple columns and puts them into one Vector column.
Vectors are what Spark models always use for fit and predict.

The example DataFrame here has a vector column.
Some rows are sparse and others are dense. All of them have two features.

@ranInc So all rows have two features, with some values missing and others not. Got it. I will try your suggestion about filtering empty rows.

As you may have guessed, I am not very familiar with the Spark ecosystem, so my debugging efforts may be quite difficult. We currently need more developers who know more about Spark and Scala programming in general. If you personally know someone who would like to help us improve the JVM package of XGBoost, please let us know.

@ranInc I tried filtering out empty rows according to your suggestion:

Program A: Example script, without filtering for empty rows

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

Program B: Example with empty row filtering

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val isVectorAllZeros = (col: Vector) => {
    col match {
      case sparse: SparseVector => (sparse.indices.isEmpty)
      case _ => false
    }
  }
  spark.udf.register("isVectorAllZeros", isVectorAllZeros)

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
                .withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
                .where("isEmpty == false")
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

Some observations

  • With the stable 1.2.0 release, Program A errors out with java.lang.NullPointerException. Right before the NPE, the following warning is displayed in the Spark execution log:
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1
  • With the stable 1.2.0 release, Program B completes successfully without error.
  • With the development version (latest master branch, commit 42d31d9dcb6f7c1cb7d0545e9ab3a305ecad0816), both Program A and Program B fail with the following error:
[12:44:57] /home/ubuntu/xgblatest/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.
Stack trace:
  [bt] (0) /tmp/libxgboost4j14081654332866852928.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x79) [0x7f7ef62c4e19]
  [bt] (1) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::ValidateDMatrix(xgboost::DMatrix*, bool) const+0x20b) [0x7f7ef63f5f0b]
  [bt] (2) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::Predict(std::shared_ptr<xgboost::DMatrix>, bool, xgboost::HostDeviceVector<float>*, unsigned int, bool, bool, bool, bool, bool)+0x3c3) [0x7f7ef6400233]
  [bt] (3) /tmp/libxgboost4j14081654332866852928.so(XGBoosterPredict+0xec) [0x7f7ef62caa3c]
  [bt] (4) /tmp/libxgboost4j14081654332866852928.so(Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict+0x47) [0x7f7ef62befd7]
  [bt] (5) [0x7f80908a8270]

which is odd because, according to @ranInc, the model was trained with data that has two features.

  • I built the 1.2.0-SNAPSHOT version from source (commit 71197d1dfa27c80add9954b10284848c1f165c40). This time, both Program A and Program B fail with the feature mismatch error (learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster).
  • The difference in behavior between the stable 1.2.0 release and 1.2.0-SNAPSHOT was unexpected and made me quite nervous. In particular, the warning message from 1.2.0
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1

is not found in the C++ codebase of the 1.2.0 version. Instead, the warning can be found in the release_1.0.0 branch:
https://github.com/dmlc/xgboost/blob/ea6b117a5737f5beb2533fc89b3f3fcd72ecc04e/src/learner.cc#L972-L982
So does this mean that the 1.2.0 JAR file on Maven Central contains the libxgboost4j.so from 1.0.0? 🤯 😱

  • Indeed, the 1.2.0 JAR file from Maven Central contains a libxgboost4j.so that is actually from 1.0.0 (!!!). To verify, download xgboost4j_2.12-1.2.0.jar from Maven Central and extract the libxgboost4j.so file from it. Then run the following Python script to check the version of the library file:
import ctypes

lib = ctypes.cdll.LoadLibrary('./libxgboost4j.so')

major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()

lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
print((major.value, minor.value, patch.value))  # prints (1, 0, 2), indicating version 1.0.2
  • Setting the 1.0.0 issue aside, we can clearly see that the trained XGBoost model recognized only a single feature (learner_model_param_.num_feature == 1). Maybe the training data had a feature that was 100% empty? @ranInc
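The extraction step mentioned in the observations above (pulling libxgboost4j.so out of the JAR before loading it with ctypes) can be sketched in Python as follows, since a JAR is an ordinary ZIP archive. The helper name `extract_native_lib` is hypothetical, and because the location of the library inside the JAR is not stated in this thread, the sketch searches for the file name instead of assuming a path.

```python
import zipfile
from pathlib import Path


def extract_native_lib(jar_path, out_dir, lib_name="libxgboost4j.so"):
    """Copy the first archive entry whose file name is `lib_name` into `out_dir`.

    Hypothetical helper for illustration. The entry's directory inside the
    JAR is not assumed; entries are matched on their last path component.
    """
    out_dir = Path(out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(jar_path) as jar:
        for entry in jar.namelist():
            if entry.split("/")[-1] == lib_name:
                target = out_dir / lib_name
                target.write_bytes(jar.read(entry))
                return target
    raise FileNotFoundError(f"{lib_name} not found in {jar_path}")


# Example call (assumes the JAR was downloaded to the current directory):
# extract_native_lib("xgboost4j_2.12-1.2.0.jar", ".")
```

The returned path can then be passed to `ctypes.cdll.LoadLibrary` as in the version-check script above.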

Do you want me to get hold of the DataFrame used to create the model?
If I can get it, I think I can write some simple Scala code that creates the model.

@ranInc My suspicion is that one of the two features in the training data consisted entirely of missing values, which set learner_model_param_.num_feature to 1. So yes, seeing the training data would be very helpful.
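The suspicion above can be illustrated with a toy Python sketch. It rests on a loudly stated assumption: that the feature count is inferred as the largest column index holding a non-missing value, plus one, with 0.0 treated as the missing marker. Whether XGBoost4J-Spark actually infers the count this way is not confirmed in this thread, and `infer_num_features` is a hypothetical helper, not a library function.

```python
def infer_num_features(rows, missing=0.0):
    """Return (largest column index with a non-missing value) + 1.

    ASSUMPTION for illustration only: this is a guess at how a feature
    count could be derived from sparse data, not XGBoost's actual logic.
    """
    max_index = -1
    for row in rows:
        for index, value in enumerate(row):
            if value != missing:
                max_index = max(max_index, index)
    return max_index + 1


# Two-column training data whose second column is always the missing marker.
training_rows = [[1.5, 0.0], [2.0, 0.0], [0.7, 0.0]]
print(infer_num_features(training_rows))  # 1

# A batch that contains only an all-zero row.
all_zero_batch = [[0.0, 0.0]]
print(infer_num_features(all_zero_batch))  # 0
```

Under that assumption, a model could end up with one feature even though every input vector has length two, and a batch made up solely of all-zero rows would appear to have zero columns.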

Alright, I think I will be able to have it ready by tomorrow.

I created #6426 to track the libxgboost4j.so mismatch issue. Let us continue the discussion here (#5957) about why learner_model_param_.num_feature is being set to 1.

It seems that is not the case: the training data has no missing values.
In the example code here, instead of relying on repartitioning to reproduce the failure, I used only one row (whose features are all zero) for prediction.

features_creation.zip

import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame

val df = spark.read.parquet("/tmp/6620294785024229130_only_features_creation").persist()
df.count()

val regressor = new XGBoostRegressor()
    .setFeaturesCol("features_6620294785024229130")
    .setLabelCol("label_6620294785024229130")
    .setPredictionCol("prediction")
    .setMissing(0.0F)
    .setMaxDepth(3)
    .setNumRound(100)
    .setNumWorkers(1)

val pipeline = new Pipeline().setStages(Array(regressor))
val model = pipeline.fit(df)

val pred = spark.read.parquet("/tmp/6620294785024229130_features").persist()
pred.count()
pred.where("account_code == 4011593987").show()
model.transform(pred.where("account_code == 4011593987")).show()