NPE μμΈλ JAVA APIλ₯Ό ν΅ν΄ μμΈ‘ν λ λ°μν©λλ€.
java.lang.NullPointerException: null
ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:309)μμ
ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:375)
com.tuhu.predict.predict.BaseModelPredict.predict(BaseModelPredict.java:71)
com.tuhu.predict.predict.XgboostFindPageModelPredict.predict(XgboostFindPageModelPredict.java:53)
com.tuhu.predict.service.impl.MlpFindPageFeatureServiceImpl.featureProcess(MlpFindPageFeatureServiceImpl.java:65)
com.tuhu.predict.api.controller.MlpFindPageController.recommendPredict(MlpFindPageController.java:49)
com.tuhu.predict.api.controller.MlpFindPageController$$FastClassBySpringCGLIB$$f694b9ff.invoke(
org.springframework.cglib.proxy.MethodProxy.invoke(MethodProxy.java:204)
org.springframework.aop.framework.CglibAopProxy$CglibMethodInvocation.invokeJoinpoint(CglibAopProxy.java:746)
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)μμ
org.springframework.aop.framework.adapter.MethodBeforeAdviceInterceptor.invoke(MethodBeforeAdviceInterceptor.java:52)
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)μμ
org.springframework.aop.aspectj.AspectJAfterAdvice.invoke(AspectJAfterAdvice.java:47)
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)μμ
org.springframework.aop.framework.adapter.AfterReturningAdviceInterceptor.invoke(AfterReturningAdviceInterceptor.java:52)
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)μμ
org.springframework.aop.aspectj.AspectJAfterThrowingAdvice.invoke(AspectJAfterThrowingAdvice.java:62)
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)μμ
org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint.proceed(MethodInvocationProceedingJoinPoint.java:88)μμ
com.tuhu.springcloud.common.annotation.AbstractControllerLogAspect.doAround(AbstractControllerLogAspect.java:104)
sun.reflect.NativeMethodAccessorImpl.invoke0μμ(λ€μ΄ν°λΈ λ©μλ)
sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)μμ
sun.reflect.DelegatingMethodAccessorImpl.invokeμμ(DelegatingMethodAccessorImpl.java:43)
java.lang.reflect.Method.invoke(Method.java:498)μμ
org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethodWithGivenArgs(AbstractAspectJAdvice.java:644)
org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethod(AbstractAspectJAdvice.java:633)
org.springframework.aop.aspectj.AspectJAroundAdvice.invoke(AspectJAroundAdvice.java:70)
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)μμ
org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:92)
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:185)μμ
org.springframework.aop.framework.CglibAopProxy$DynamicAdvisedInterceptor.intercept(CglibAopProxy.java:688)
com.tuhu.predict.api.controller.MlpFindPageController$$EnhancerBySpringCGLIB$$560ed775.recommendPredict(
sun.reflect.NativeMethodAccessorImpl.invoke0μμ(λ€μ΄ν°λΈ λ©μλ)
sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)μμ
sun.reflect.DelegatingMethodAccessorImpl.invokeμμ(DelegatingMethodAccessorImpl.java:43)
java.lang.reflect.Method.invoke(Method.java:498)μμ
org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:209)
org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:136)
org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:102)
org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:877)
org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:783)
org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:991)μμ
org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:925)μμ
org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:974)
org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:877)μμ
javax.servlet.http.HttpServlet.service(HttpServlet.java:661)μμ
org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:851)μμ
javax.servlet.http.HttpServlet.service(HttpServlet.java:742)μμ
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:52)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
com.tuhu.soter.starter.filter.SoterDefaultFilter.doFilter(SoterDefaultFilter.java:79)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
com.tuhu.boot.logback.filter.LogFilter.doFilter(LogFilter.java:54)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:158)
org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:126)
org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.doFilterInternal(WebMvcMetricsFilter.java:111)
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter.doFilterInternal(HttpTraceFilter.java:90)
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
com.tuhu.boot.common.filter.HeartbeatFilter.doFilter(HeartbeatFilter.java:42)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
com.tuhu.boot.common.filter.MDCFilter.doFilter(MDCFilter.java:47)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.springframework.web.filter.HttpPutFormContentFilter.doFilterInternal(HttpPutFormContentFilter.java:109)
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.springframework.web.filter.HiddenHttpMethodFilter.doFilterInternal(HiddenHttpMethodFilter.java:93)
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:200)
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:198)
org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:96)
org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:496)
org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:140)
org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:81)
org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:87)
org.apache.catalina.valves.RemoteIpValve.invoke(RemoteIpValve.java:677)
org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:342)
org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:803)
org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:66)
org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:790)
org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1468)
org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)μμ
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
java.lang.Thread.run(Thread.java:748)μμ
λͺ¨λΈμ Python Sklearnμ ν΅ν΄ νλ ¨λκΈ° λλ¬Έμ λμ€μ λΉνΈνμ±μ΄ λ°μν©λλ€. μκ°μ μ μ½νκΈ° μν΄ μκ³ λ¦¬μ¦ νμ Sklearnμμ νλ ¨λ XGB λͺ¨λΈμ Python XgBoost ν¨ν€μ§ μλ‘ ν κ³μΈ΅ μ΄λνμ΅λλ€. κ·Έκ² μμΈμ΄ λμλμ§ κΆκΈν©λλ€.
μ΄λ€ λ²μ μ XGBoostλ₯Ό μ¬μ©νκ³ μμ΅λκΉ? μ΄μ μ μμΈ‘μ΄ μ€ν¨νκ³ λΉ μμΈ‘ λ²νΌλ₯Ό κ³μ μ¬μ©ν λ jvm ν¨ν€μ§κ° μμΈλ₯Ό μ¬λ°λ₯΄κ² throwνμ§ μλ λ²κ·Έλ₯Ό μμ νμ΅λλ€.
μ΄λ€ λ²μ μ XGBoostλ₯Ό μ¬μ©νκ³ μμ΅λκΉ? μ΄μ μ μμΈ‘μ΄ μ€ν¨νκ³ λΉ μμΈ‘ λ²νΌλ₯Ό κ³μ μ¬μ©ν λ jvm ν¨ν€μ§κ° μμΈλ₯Ό μ¬λ°λ₯΄κ² throwνμ§ μλ λ²κ·Έλ₯Ό μμ νμ΅λλ€.
νμ¬ μκ³ λ¦¬μ¦ νλ«νΌ λ²μ 1.0μ μ¬μ©νκ³ μμΌλ©°, λ²μ νΈνμ± λ¬Έμ λ‘ μκ³ λ¦¬μ¦ νλ‘μ νΈ λ²μ 0.9.0μ μ¬μ©νκ³ μμ΅λλ€. μκ³ λ¦¬μ¦ λλ£λ€μ Pythonμ μ¬μ©νμ¬ 1.0 λͺ¨λΈ νμΌμ 0.9.0μΌλ‘ λ³ννμ΅λλ€. μ΄ λ³νμ μν κ²μΈμ§ κΆκΈν©λλ€.
1.2(https://github.com/dmlc/xgboost/issues/5734)λ₯Ό κΈ°λ€λ Έλ€κ° λ€μ μλνλ κ²μ΄ μ’μ΅λλ€. μ΄ λ¦΄λ¦¬μ€μλ λͺ κ°μ§ μ€μν λ²κ·Έ μμ μ¬νμ΄ μμ΅λλ€. λν μμΈ‘μ μν΄ λμΌνκ±°λ μ΄νμ xgboost λ²μ μ μ¬μ©νλ κ²μ΄ μ’μ΅λλ€. XGBoostμ λ°μ΄λ리 λͺ¨λΈμ μ΄μ λ²μ κ³Ό νΈνλλ―λ‘ μμΌλ‘ JSON κΈ°λ° λͺ¨λΈμ κΆμ₯ν©λλ€.
1.2.0μμλ κ°μ λ¬Έμ κ° λ°μνμ΅λλ€. κ·Έλμ λ¬Έμ λ μ¬μ ν μ¬κΈ°μ μμ΅λλ€.
λλ λν κ°μ λ¬Έμ κ° λ°μνμ΅λλ€.
xgboost4jλ₯Ό μ¬μ©νμ¬ λͺ¨λΈμ μμ±νμ΅λλ€.
ν΄κ²° λ°©λ²μ΄ μμ΅λκΉ?
μ΄κ²μ μ μκ² ν° λ¬Έμ μ λλ€. μμ° μμ μ μ€ν¨νμ΅λλ€.
@ranInc μ΅μ λ²μ μ XGBoostλ₯Ό μ¬μ©νκ³ μμ΅λκΉ? μ§κΈκΉμ§ μ°λ¦¬λ μ΄ λ¬Έμ μ μ νν μμΈμ μμ§ λͺ»ν©λλ€. μ΅μ μ λ€ν΄ ν΄κ²°ν΄ λλ¦¬κ² μ΅λλ€. μΈμ λ¬Έμ κ° ν΄κ²°λ μ§ λ³΄μ₯ν μ μμΌλ―λ‘ κ·Έ λμ λμμ μ‘°μ¬ν΄ 보μκΈ° λ°λλλ€.
@ranInc λΉμ μ μ°λ¦¬(κ°λ°μ)κ° μ°λ¦¬ μμ μ κΈ°κ³μμ μ€νν μ μλ μμ μμ νλ‘κ·Έλ¨μ μ 곡ν¨μΌλ‘μ¨ μ°λ¦¬λ₯Ό λμΈ μ μμ΅λλ€.
maven μ μ₯μμ μ΅μ jarμΈ 1.2.0μ μ€ν μ€μ λλ€.
μ λ₯Ό μν ν΄κ²° λ°©λ²μ spark 2.4.5λ‘ λμκ° xgboost 0.9λ₯Ό μ¬μ©νλ κ²μ λλ€ - κ·Έλ¦¬κ³ μ΄κ²μ΄ μ κ° μ§κΈ νκ³ μλ μΌμ λλ€.
μλ₯Ό λ€μ΄: λμ€μ μμ μ€ν¨μ μμΈμ΄ λλ νΉμ λͺ¨λΈ/λ°μ΄ν°λ₯Ό μ°Ύμλ΄κ² μ΅λλ€.
μλ νμΈμ,
νΉμ λͺ¨λΈ/λ°μ΄ν°λ₯Ό μ°Ύμμ΅λλ€.
μ΄ λκΈμ 첨λΆν©λλ€.
xgb1.2_bug.zip
μ΄κ²μ΄ λ²κ·Έλ₯Ό μ¬ννλ λ°©λ²μ λλ€(μ¬κΈ°μ λ€μ λΆν νμ§ μμΌλ©΄ λμνλ―λ‘, λ¬Έμ κ° κ° νν°μ μμ λ°μ΄ν° μ΄λ λ°μ΄ν° μ νκ³Ό κ΄λ ¨ μμμ μ μνμμμ€):
from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame
# Reproduction script (PySpark shell: the `spark` session is provided by the REPL).
# Load the feature parquet, force 200 partitions, and cache it.
df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()
# Restore the persisted pipeline that wraps only the XGBoost stage.
model = PipelineModel.read().load("/tmp/6620294785024229130_model_xg_only")
# Scoring the cached frame is what triggers the NPE reported in this thread.
predictions = model.transform(df)
predictions.persist()
predictions.count()
predictions.show()
μ΄ λ¬Έμ λ₯Ό μΈμ ν΄κ²°ν μ μλμ§ μκ³ μμ΅λκΉ?
μ΄κ²μ λ΄κ° spark 3.0μ μ¬μ©νλ κ²μ λ°©μ§ν©λλ€ ...
@ranInc μμ§ μλλλ€. λ²κ·Έ μμ μ΄ μλ£λλ©΄ μλ €λλ¦¬κ² μ΅λλ€. λν Scalaμ μ½λλ₯Ό κ²μν μ μμ΅λκΉ? μ λ μ°λ¦¬κ° XGBoostμ ν¨κ» PySparkμ μ¬μ©μ 곡μμ μΌλ‘ μ§μν μ μ΄ μλ€κ³ μκ°ν©λλ€.
import org.apache.spark.ml.{Pipeline, PipelineModel}
// Scala equivalent of the PySpark reproduction (spark-shell: `spark` is predefined).
// Load the feature parquet, force 200 partitions, and cache it.
val df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()
// Load the persisted pipeline containing only the XGBoost stage.
val model = PipelineModel.read.load("/tmp/6620294785024229130_model_xg_only")
// Transforming (scoring) is the step that fails, per the thread.
val predictions = model.transform(df)
predictions.persist()
predictions.count()
predictions.show()
λ λ€λ₯Έ ν¬μΈν°,
λ¬Έμ λ μμΈ‘νκΈ° μν΄ μ μ‘λλ λͺ¨λ κΈ°λ₯μ΄ 0/λλ½μ΄κΈ° λλ¬ΈμΈ κ² κ°μ΅λλ€.
μ무λ μ΄κ²μ λν΄ μμ νκ³ μμ§ μμ κ² κ°λ€μ?
μ΄κ²μ κΈ°λ³Έμ μΌλ‘ xgboostκ° spark 3μμ μ ν μλνμ§ μλλ€λ κ²μ μλ―Έν©λλ€.
λ€, μ£μ‘ν©λλ€. μ§κΈ μμ΄ κ½ μ°Όμ΅λλ€. μ°λ¦¬λ μ΄λ μμ μμ μ΄ λ¬Έμ λ₯Ό λ€λ£° κ²μ λλ€. κ·νμ μν΄λ₯Ό μ μ€ν λΆνλ립λλ€. κ°μ¬ ν΄μ.
@ranInc μ€λ μκ°μ΄ μ’ λμ μ¬κΈ°μμ μ 곡ν μ€ν¬λ¦½νΈλ₯Ό μ€νν΄ λ³΄μμ΅λλ€. java.lang.NullPointerException
μ€λ₯λ₯Ό μ¬ννμ΅λλ€.
μ΄μνκ²λ μ΅μ κ°λ° λ²μ ( master
λΆκΈ°)μ κ°μ λ°©μμΌλ‘ μΆ©λνμ§ μμ΅λλ€. λμ μ€λ₯κ° λ°μν©λλ€.
μ€λ λ "main" org.apache.spark.SparkExceptionμ μμΈ: λ¨κ³ μ€λ₯λ‘ μΈν΄ μμ μ€λ¨: λ¨κ³ 7.0μ μμ 0μ΄ 1ν μ€ν¨, κ°μ₯ μ΅κ·Όμ μ€ν¨: λ¨κ³ 7.0μ μμ 0.0 μμ€(TID 11, d04389c5babb, μ€νκΈ° λλΌμ΄λ²): ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: νμΈ μ€ν¨: Learner_model_param_.num_feature >= p_fmat->Info().num_col_(1 λ 2) : μ΄ μκ° λΆμ€ν°μ κΈ°λ₯ μμ μΌμΉνμ§ μμ΅λλ€.
λ μ‘°μ¬νκ² μ΅λλ€.
μ€λ₯ λ©μμ§κ° μ΄μ μλ―Έκ° μλ€κ³ μκ°ν©λλ€. μ λ ₯μ μμΈ‘ λͺ¨λΈλ³΄λ€ λ λ§μ κΈ°λ₯μ΄ μμ΅λλ€.
xgboost μ€ν¨ ν jvm ν¨ν€μ§κ° κ³μλκΈ° μ μ λΉ μμΈ‘ λ²νΌκ° μμ±λ©λλ€. μ΅κ·Όμ μ²΄ν¬ κ°λλ₯Ό μΆκ°νμ΅λλ€.
νλ ¨ λ°μ΄ν° μΈνΈμ μ΄ μκ° μμΈ‘ λ°μ΄ν° μΈνΈλ³΄λ€ ν¬κ±°λ κ°μμ§ νμΈνμμμ€.
μλ νμΈμ,
λͺ¨λΈμ λμΌν μμ κΈ°λ₯μ μ¬μ©νμ¬ μμ±λμμ΅λλ€.
μ€νν¬μμλ μ¬λ¬ μ΄μ΄ μλ νλμ λ²‘ν° μ΄μ μ¬μ©ν©λλ€.
μ΄μ¨λ 벑ν°μ ν¬κΈ°λ μ ν©κ³Ό μμΈ‘μ μν΄ νμ λμΌν©λλ€. 100% νμ ν©λλ€.
μ΄κ²μ λͺ¨λ 0/λλ½λ κΈ°λ₯μ΄ μλ νκ³Ό κ΄λ ¨μ΄ μμ΅λλ€.
λͺ¨λ κΈ°λ₯μ΄ 0μΈ νμ λ°μ΄ν° νλ μμμ νν°λ§νλ©΄ μ λλ‘ μλν¨μ μ μ μμ΅λλ€.
@ranInc λͺ¨λΈμ μμ±ν μ 체 Scala νλ‘κ·Έλ¨μ κ²μν μ μμ΅λκΉ? μ€λ₯ λ©μμ§λ λͺ¨λΈμ΄ λ¨μΌ κΈ°λ₯μΌλ‘ νλ ¨λμμμ μμ¬νλ κ² κ°μ΅λλ€.
μ½λκ° λ§€μ° μΌλ°μ μ΄κ³ μ μμ μΈ λ³νκΈ°κ° μκΈ° λλ¬Έμ λ§μ λμμ΄ λ κ²μ΄λΌκ³ μκ°νμ§ μμ΅λλ€.
μ½λ μ체λ λλΆλΆ μ€μΉΌλΌκ° μλ pysparkμ λλ€.
κΈ°λ₯ μκ° λ¬Έμ κ° μλμ§ νμΈνλ κ°μ₯ μ’μ λ°©λ²μ λͺ¨λ κΈ°λ₯μ΄ 0μΈ νμ νν°λ§ν λ€ λͺ¨λΈμ μ¬μ©ν΄ 보λ κ²μ λλ€. μ΄λ κ² νλ©΄ λ¬Έμ μμ΄ μλν©λλ€.
λν λͺ¨λ νμ μ μ§ν μ±λ‘ λ°μ΄ν° νλ μμ νλμ νν°μ μΌλ‘ λ€μ λΆν ν μλ μλλ°, κ·Έλ κ² ν΄λ μλν©λλ€.
@ranInc 0μΌλ‘λ§ μ΄λ£¨μ΄μ§ νμ νν°λ§νμ§λ§ μ¬μ ν λμΌν μ€λ₯(java.lang.NullPointerException)μ μ§λ©΄νμ΅λλ€.
...
df.na.drop(minNonNulls = 1)
...
μ΄κ²μ΄ μ¬λ°λ₯Έ λ°©λ²μ΄ μλκ°?
μ½λκ° λ§€μ° μΌλ°μ μ΄κ³ μΌλΆ μ μ μ μΈ λ³νκΈ°κ° μκΈ° λλ¬Έμ λ§μ λμμ΄ λ κ²μ΄λΌκ³ μκ°νμ§ μμ΅λλ€.
νμ΅ λ° μμΈ‘ μκ°μ μΌλ§λ λ§μ κΈ°λ₯μ΄ μ¬μ©λκ³ μλμ§ νμΈνκ³ μΆμ΅λλ€. μ€λ₯ λ©μμ§
ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: νμΈ μ€ν¨: Learner_model_param_.num_feature >= p_fmat->Info().num_col_(1 λ 2) : μ΄ μκ° λΆμ€ν°μ κΈ°λ₯ μμ μΌμΉνμ§ μμ΅λλ€.
λͺ¨λΈμ΄ λ¨μΌ κΈ°λ₯μΌλ‘ νμ΅λμκ³ μμΈ‘μ΄ λ κ°μ§ κΈ°λ₯μΌλ‘ μνλκ³ μμμ λνλ λλ€.
μ§κΈμ μ λ‘λν λ°μ΄ν° νλ μκ³Ό μ§λ ¬νλ λͺ¨λΈμλ§ μ‘μΈμ€ν μ μμ΅λλ€. λͺ¨λΈ κ΅μ‘μ 무μμ΄ λ€μ΄κ°κ³ 무μμ΄ μλͺ»λμλμ§μ λν ν΅μ°°λ ₯μ΄ λΆμ‘±νμ¬ λ μ΄μ λ¬Έμ λ₯Ό ν΄κ²°νλ λ° λ°©ν΄κ° λ©λλ€. νλ‘κ·Έλ¨μ λ μ μ λ³΄κ° μλ κ²½μ° κΉ¨λν μμ λ₯Ό μμ±ν μ μμ΅λκΉ?
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, col}
.......
// UDF: true only for a SparseVector that stores no entries — i.e. every
// component is an implicit zero. Dense vectors always report false, even
// if all their values happen to be zero.
// NOTE(review): the lambda parameter `col` shadows the imported
// org.apache.spark.sql.functions.col used below — confusing but harmless here.
val isVectorAllZeros = (col: Vector) => {
col match {
case sparse: SparseVector =>
if(sparse.indices.isEmpty){
true
}else{
false
}
case _ => false
}
}
spark.udf.register("isVectorAllZeros", isVectorAllZeros)
// Keep only rows whose feature vector is not all-zero.
// NOTE(review): `df = ...` reassignment implies `df` is a var or this is
// REPL-style pseudo-code quoted from the report — not compilable as-is.
df = df.withColumn("isEmpty",callUDF("isVectorAllZeros",
col("features_6620294785024229130"))).where("isEmpty == false")
λ€μκ³Ό κ°μ΄ λ°μ΄ν° νλ μμ λ€μ λΆν ν μλ μμ΅λλ€.
....
df = df.repartition(1)
....
νμ§λ§ κ°μ μμ κΈ°λ₯μ μ¬μ©νλ€κ³ 100% νμ ν©λλ€.
VectorAssemblerκ° λ€μν μμ κΈ°λ₯μ κ°λλ‘ νλ κ²½μ° μ΄λ₯Ό μ΄λ»κ² 보μ₯νμ΅λκΉ?
VectorAssemblerλ νμ λμΌν μμ κΈ°λ₯μ μμ±νλ©° κ°μ Έμ¬ μ΄ μ΄λ¦λ§ μμΌλ©΄ λ©λλ€.
μ½λ μ체λ μμ² κ°μ λͺ¨λΈμ λ§λλ λ° μ¬μ©λλ―λ‘ λ§€μ° μΌλ°μ μ΄λ©° κΈ°λ³Έμ μΌλ‘ μ¬μ©ν μ΄λ¦ λͺ©λ‘μ κ°μ Έμ΅λλ€.
λͺ¨λΈ μμ±μ λ€μ μ€ννκ³ λͺ¨λΈμ μ¬μ©λ λ°μ΄ν° νλ μ λλ νμν κΈ°ν λ°μ΄ν°λ₯Ό λ³΄λΌ μ μμ΅λλ€.
νμ§λ§ μκ°μ΄ κ±Έλ¦¬κ³ μ΄μ μ 보μ¬λλ¦° κ²μ μ¬μ©νλ©΄ λͺ¨λΈμ΄ 2κ°μ§ κΈ°λ₯μΌλ‘ μ μλνλ κ²μ λ³Ό μ μμ΅λλ€.
@ranInc ν κ°μ§ λ μ§λ¬Ένκ² μ΅λλ€. μμ λ°μ΄ν° μ μ μ΅λ λ κ°μ§ κΈ°λ₯μ΄ λ€μ΄ μλ ν¬μ μ΄(VectorAssembler)μ΄ μλ€κ³ μ΄ν΄νλ κ²μ΄ λ§μ΅λκΉ?
μλμ.
VectorAssemblerλ μ¬λ¬ μ΄μ κ°μ Έμ νλμ Vector μ΄μ λ£λ Transformerμ λλ€.
벑ν°λ νμ μ€νν¬μ μ ν©νκ³ μμΈ‘νλ λͺ¨λΈμ μ¬μ©λ©λλ€.
μ¬κΈ° μμ λ°μ΄ν° νλ μμλ λ²‘ν° μ΄μ΄ μμ΅λλ€.
μΌλΆ νμλ ν¬μμ±μ΄ μκ³ λ€λ₯Έ νμλ λ°μ§λμ΄ μμ΅λλ€. λͺ¨λ λ κ°μ§ κΈ°λ₯μ κ°μ§λλ€.
@ranInc λ°λΌμ λͺ¨λ νμλ λ κ°μ§ κΈ°λ₯μ΄ μμΌλ©° μΌλΆ κ°μ λλ½λκ³ λ€λ₯Έ κ°μ λλ½λ©λλ€. μμλ€. λΉ ν νν°λ§μ λν μ μμ μλν΄ λ³΄κ² μ΅λλ€.
μ§μνμ ¨κ² μ§λ§ μ λ Spark μμ½μμ€ν μ μ΅μνμ§ μκΈ° λλ¬Έμ λλ²κΉ λ Έλ ₯μ΄ μλΉν μ΄λ €μΈ μ μμ΅λλ€. μ°λ¦¬λ νμ¬ μΌλ°μ μΌλ‘ Spark λ° Scala νλ‘κ·Έλλ°μ λν΄ λ λ§μ΄ μκ³ μλ λ λ§μ κ°λ°μκ° νμν©λλ€. XGBoostμ JVM ν¨ν€μ§λ₯Ό κ°μ νλ λ° λμμ μ£Όκ³ μΆμ μ¬λμ κ°μΈμ μΌλ‘ μκ³ κ³μλ€λ©΄ μλ €μ£Όμμμ€.
@ranInc κ·νμ μ μμ λ°λΌ λΉ νμ νν°λ§νλ €κ³ μλνμ΅λλ€.
νλ‘κ·Έλ¨ A: λΉ ν νν°λ§μ΄ μλ μμ μ€ν¬λ¦½νΈ
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}
object Main extends App {
  // Program A: score the raw frame with no filtering of empty rows.
  val session = SparkSession
    .builder()
    .appName("XGBoost4J-Spark Pipeline Example")
    .getOrCreate()

  // Read the feature parquet, spread it over 200 partitions, and cache it.
  val features = session.read
    .parquet("/home/ubuntu/data/6620294785024229130_features")
    .repartition(200)
    .persist()
  features.show()

  // Restore the persisted XGBoost pipeline and score the cached frame.
  val pipeline = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")
  val scored = pipeline.transform(features)
  scored.persist()
  scored.count()
  scored.show()
}
νλ‘κ·Έλ¨ B: λΉ ν νν°λ§μ΄ μλ μ
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}
object Main extends App {
  // Program B: same scoring run, but rows whose feature vector is all-zero
  // are filtered out first.
  val session = SparkSession
    .builder()
    .appName("XGBoost4J-Spark Pipeline Example")
    .getOrCreate()

  // True only for a SparseVector with no stored (non-zero) entries;
  // any other Vector implementation reports false.
  val vectorIsAllZeros = (vec: Vector) => vec match {
    case sv: SparseVector => sv.indices.isEmpty
    case _                => false
  }
  session.udf.register("isVectorAllZeros", vectorIsAllZeros)

  // Load, cache, tag each row via the UDF, and keep only non-empty vectors.
  val nonEmptyRows = session.read
    .parquet("/home/ubuntu/data/6620294785024229130_features")
    .repartition(200)
    .persist()
    .withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
    .where("isEmpty == false")
  nonEmptyRows.show()

  // Restore the persisted XGBoost pipeline and score the filtered frame.
  val pipeline = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")
  val scored = pipeline.transform(nonEmptyRows)
  scored.persist()
  scored.count()
  scored.show()
}
java.lang.NullPointerException
μ€λ₯κ° λ°μν©λλ€. NPE μ§μ μ Spark μ€ν λ‘κ·Έμ λ€μ κ²½κ³ κ° νμλ©λλ€.WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1
master
λΆκΈ°, μ»€λ° 42d31d9dcb6f7c1cb7d0545e9ab3a305ecad0816)μ μ¬μ©νλ©΄ νλ‘κ·Έλ¨ A μ νλ‘κ·Έλ¨ B κ° λͺ¨λ λ€μ μ€λ₯μ ν¨κ» μ€ν¨ν©λλ€.[12:44:57] /home/ubuntu/xgblatest/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.
Stack trace:
[bt] (0) /tmp/libxgboost4j14081654332866852928.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x79) [0x7f7ef62c4e19] [bt] (1) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::ValidateDMatrix(xgboost::DMatrix*, bool) const+0x20b) [0x7f7ef63f5f0b]
[bt] (2) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::Predict(std::shared_ptr<xgboost::DMatrix>, bool, xgboost::HostDeviceVector<float>*, unsigned int, bool, bool, bool, bool, bool)+0x3c3) [0x7f7ef6400233]
[bt] (3) /tmp/libxgboost4j14081654332866852928.so(XGBoosterPredict+0xec) [0x7f7ef62caa3c] [bt] (4) /tmp/libxgboost4j14081654332866852928.so(Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict+0x47) [0x7f7ef62befd7]
[bt] (5) [0x7f80908a8270]
@ranInc μ λ°λ₯΄λ©΄ λͺ¨λΈμ λ κ°μ§ κΈ°λ₯μ΄ μλ λ°μ΄ν°λ‘ νλ ¨λμκΈ° λλ¬Έμ μ΄μν©λλ€.
1.2.0-SNAPSHOT
λ²μ μ λΉλνμ΅λλ€(μ»€λ° 71197d1dfa27c80add9954b10284848c1f165c40). μ΄λ²μλ νλ‘κ·Έλ¨ A μ νλ‘κ·Έλ¨ B λͺ¨λ κΈ°λ₯ λΆμΌμΉ μ€λ₯( learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster
)λ‘ μ€ν¨ν©λλ€.1.2.0-SNAPSHOT
μ λμ μ°¨μ΄λ μμνμ§ λͺ»νλ μΌμ΄μκ³ μ λ₯Ό μλΉν κΈ΄μ₯νκ² λ§λ€μμ΅λλ€. νΉν 1.2.0μ κ²½κ³ λ©μμ§WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1
1.2.0 λ²μ μ C++ μ½λλ² μ΄μ€μλ μμ΅λλ€. λμ release_1.0.0
λΆκΈ°μμ κ²½κ³ λ₯Ό μ°Ύμ μ μμ΅λλ€.
https://github.com/dmlc/xgboost/blob/ea6b117a5737f5beb2533fc89b3f3fcd72ecc04e/src/learner.cc#L972 -L982
κ·Έλ λ€λ©΄ Maven Centralμ 1.2.0 JAR νμΌμ 1.0.0μ libxgboost4j.soκ° λ€μ΄ μλ€λ λ»μΈκ°μ? π€― π±
λ€, libxgboost4j.soκ° ν¬ν¨λμ΄ μμ΅λλ€. νμΈνλ €λ©΄ Maven Centralμμ xgboost4j_2.12-1.2.0.jarλ₯Ό λ€μ΄λ‘λνκ³ libxgboost4j.so νμΌμ μΆμΆνμμμ€. κ·Έλ° λ€μ λ€μ Python μ€ν¬λ¦½νΈλ₯Ό μ€ννμ¬ λΌμ΄λΈλ¬λ¦¬ νμΌμ λ²μ μ νμΈν©λλ€.
import ctypes
# Load the native library extracted from the jar and query its version triple.
native = ctypes.cdll.LoadLibrary('./libxgboost4j.so')
# XGBoostVersion writes major/minor/patch through the three int pointers.
major, minor, patch = ctypes.c_int(), ctypes.c_int(), ctypes.c_int()
native.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
print((major.value, minor.value, patch.value))  # prints (1, 0, 2), indicating version 1.0.2
λΆμ€ν°κ° κΈ°λ₯μ νλ(learner_model_param_.num_feature == 1)λ§ μΈμνμμ λΆλͺ ν μ μ μμ΅λλ€. νλ ¨ λ°μ΄ν°μ 100% λΉμ΄ μλ κΈ°λ₯μ΄ μμμκΉμ? @ranInc λͺ¨λΈμ λ§λλ λ° μ¬μ©λ λ°μ΄ν° νλ μμ κ°μ Έμ€μκ² μ΅λκΉ?
μ‘μ μλ§ μλ€λ©΄ λͺ¨λΈμ μμ±νλ κ°λ¨ν μ€μΉΌλΌ μ½λλ₯Ό λ§λ€ μ μμ κ² κ°μμ.
@ranInc μ μ μμ¬μ νλ ¨ λ°μ΄ν°μ λ κΈ°λ₯ μ€ νλκ° μμ ν κ²°μΈ‘κ°μΌλ‘ ꡬμ±λμ΄ learner_model_param_.num_featureκ° 1λ‘ μ€μ λμλ€λ κ²μ λλ€. μ, νλ ¨ λ°μ΄ν°λ₯Ό λ³Ό μ μλ€λ©΄ ν° λμμ΄ λ κ²μ λλ€.
μκ² μ΅λλ€. λ΄μΌκΉμ§ μ€λΉν μ μμ κ² κ°μ΅λλ€.
libxgboost4j.so λΆμΌμΉ λ¬Έμ λ₯Ό μΆμ νκΈ° μν΄ #6426μ λ§λ€μμ΅λλ€. learner_model_param_.num_featureκ° 1λ‘ μ€μ λλ μ΄μ μ λν λ Όμλ μ¬κΈ°(#5957)μμ κ³μνκ² μ΅λλ€.
μλͺ»λ κ² κ°μ΅λλ€. νλ ¨ λ°μ΄ν°μλ κ²°μΈ‘κ°μ΄ μμ΅λλ€.
μ¬κΈ° μμ μ½λμμ μ€ν¨λ₯Ό μ¬ννκΈ° μν΄ μ¬λΆν μ μ€κ³νλ λμ μμΈ‘μ νλμ ν(κΈ°λ₯μ΄ 0κ°λ§ μλ)λ§ μ¬μ©νμ΅λλ€.
import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame
// Train-side reproduction: fit an XGBoost regression pipeline, then score a
// single row whose feature vector contains only zeros (spark-shell script).
val df = spark.read.parquet("/tmp/6620294785024229130_only_features_creation").persist()
df.count()
// NOTE(review): setMissing(0.0F) makes XGBoost treat 0 as a missing value —
// presumably why an all-zero row presents fewer columns than the booster
// was trained with; confirm against the num_feature mismatch in the thread.
val regressor = new XGBoostRegressor()
.setFeaturesCol("features_6620294785024229130")
.setLabelCol("label_6620294785024229130")
.setPredictionCol("prediction")
.setMissing(0.0F)
.setMaxDepth(3)
.setNumRound(100)
.setNumWorkers(1)
val pipeline = new Pipeline().setStages(Array(regressor))
val model = pipeline.fit(df)
// Score only the one account whose row reproduces the failure.
val pred = spark.read.parquet("/tmp/6620294785024229130_features").persist()
pred.count()
pred.where("account_code == 4011593987").show()
model.transform(pred.where("account_code == 4011593987")).show()
κ°μ₯ μ μ©ν λκΈ
μκ² μ΅λλ€. λ΄μΌκΉμ§ μ€λΉν μ μμ κ² κ°μ΅λλ€.