Xgboost: [jvm-packages] java.lang.NullPointerException: 在 ml.dmlc.xgboost4j.java.Booster.predict 处为空

创建于 2020-07-30  ·  37评论  ·  资料来源: dmlc/xgboost

通过 JAVA API 预测时会发生 NPE 异常。

java.lang.NullPointerException:空
在 ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:309)
在 ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:375)
在 com.tuhu.predict.predict.BaseModelPredict.predict(BaseModelPredict.java:71)
在 com.tuhu.predict.predict.XgboostFindPageModelPredict.predict(XgboostFindPageModelPredict.java:53)
在 com.tuhu.predict.service.impl.MlpFindPageFeatureServiceImpl.featureProcess(MlpFindPageFeatureServiceImpl.java:65)
在 com.tuhu.predict.api.controller.MlpFindPageController.recommendPredict(MlpFindPageController.java:49)
在 com.tuhu.predict.api.controller.MlpFindPageController$$FastClassBySpringCGLIB$$f694b9ff.invoke()
在 org.springframework.cglib.proxy.MethodProxy.invoke(MethodProxy.java:204)
在 org.springframework.aop.framework.CglibAopProxy$CglibMethodInvocation.invokeJoinpoint(CglibAopProxy.java:746)
在 org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)
在 org.springframework.aop.framework.adapter.MethodBeforeAdviceInterceptor.invoke(MethodBeforeAdviceInterceptor.java:52)
在 org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
在 org.springframework.aop.aspectj.AspectJAfterAdvice.invoke(AspectJAfterAdvice.java:47)
在 org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
在 org.springframework.aop.framework.adapter.AfterReturningAdviceInterceptor.invoke(AfterReturningAdviceInterceptor.java:52)
在 org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
在 org.springframework.aop.aspectj.AspectJAfterThrowingAdvice.invoke(AspectJAfterThrowingAdvice.java:62)
在 org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
在 org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint.proceed(MethodInvocationProceedingJoinPoint.java:88)
在 com.tuhu.springcloud.common.annotation.AbstractControllerLogAspect.doAround(AbstractControllerLogAspect.java:104)
在 sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
在 sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
在 sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
在 java.lang.reflect.Method.invoke(Method.java:498)
在 org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethodWithGivenArgs(AbstractAspectJAdvice.java:644)
在 org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethod(AbstractAspectJAdvice.java:633)
在 org.springframework.aop.aspectj.AspectJAroundAdvice.invoke(AspectJAroundAdvice.java:70)
在 org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
在 org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:92)
在 org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:185)
在 org.springframework.aop.framework.CglibAopProxy$DynamicAdvisedInterceptor.intercept(CglibAopProxy.java:688)
在 com.tuhu.predict.api.controller.MlpFindPageController$$EnhancerBySpringCGLIB$$560ed775.recommendPredict()
在 sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method)
在 sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
在 sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
在 java.lang.reflect.Method.invoke(Method.java:498)
在 org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:209)
在 org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:136)
在 org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:102)
在 org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:877)
在 org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:783)
在 org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
在 org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:991)
在 org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:925)
在 org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:974)
在 org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:877)
在 javax.servlet.http.HttpServlet.service(HttpServlet.java:661)
在 org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:851)
在 javax.servlet.http.HttpServlet.service(HttpServlet.java:742)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:52)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 com.tuhu.soter.starter.filter.SoterDefaultFilter.doFilter(SoterDefaultFilter.java:79)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 com.tuhu.boot.logback.filter.LogFilter.doFilter(LogFilter.java:54)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:158)
在 org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:126)
在 org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.doFilterInternal(WebMvcMetricsFilter.java:111)
在 org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter.doFilterInternal(HttpTraceFilter.java:90)
在 org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 com.tuhu.boot.common.filter.HeartbeatFilter.doFilter(HeartbeatFilter.java:42)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 com.tuhu.boot.common.filter.MDCFilter.doFilter(MDCFilter.java:47)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)
在 org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.springframework.web.filter.HttpPutFormContentFilter.doFilterInternal(HttpPutFormContentFilter.java:109)
在 org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.springframework.web.filter.HiddenHttpMethodFilter.doFilterInternal(HiddenHttpMethodFilter.java:93)
在 org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:200)
在 org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
在 org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
在 org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
在 org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:198)
在 org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:96)
在 org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:496)
在 org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:140)
在 org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:81)
在 org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:87)
在 org.apache.catalina.valves.RemoteIpValve.invoke(RemoteIpValve.java:677)
在 org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:342)
在 org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:803)
在 org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:66)
在 org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:790)
在 org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1468)
在 org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
在 java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
在 java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
在 org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
在 java.lang.Thread.run(Thread.java:748)

最有用的评论

好吧,我想我会在明天之前把它准备好。

所有37条评论

因为模型是通过 Python Sklearn 训练的,所以后面会出现不兼容的情况。为了节省时间,算法团队将 Sklearn 训练的 XGB 模型在 Python XgBoost 包上移了一层。不知道是不是这个原因

image

您使用的是哪个版本的 XGBoost? 之前我们修复了预测失败时 jvm 包没有正确抛出异常并继续使用空预测缓冲区的错误。

您使用的是哪个版本的 XGBoost? 之前我们修复了预测失败时 jvm 包没有正确抛出异常并继续使用空预测缓冲区的错误。

使用的是公司算法平台的1.0版本,由于版本兼容性问题,使用了算法项目的0.9.0版本。算法同事使用Python将1.0的模型文件转换为0.9.0。 不知道是不是这个变形造成的

我建议等待 1.2 (https://github.com/dmlc/xgboost/issues/5734) 再试一次,我们在这个版本中有一些重要的错误修复。 此外,我建议使用相同或更高版本的 xgboost 进行预测。 XGBoost 的二进制模型向后兼容,向前移动,推荐基于 JSON 的模型。

我在 1.2.0 中遇到了同样的问题。 所以问题还是在这里。

我也遇到了同样的问题。
我使用 xgboost4j 创建模型。

有解决方法吗?

这对我来说是一个大问题,它在生产中失败了。

@ranInc你使用的是最新版本的 XGBoost 吗? 到目前为止,我们还不知道这个问题的确切原因。 我们将尽最大努力解决它,由于无法保证何时可以解决该问题,我建议您同时调查替代方案。

@ranInc您可以通过提供一个我们(开发人员)可以在我们自己的机器上运行的小示例程序来帮助我们。

我正在运行 1.2.0,这是 maven 存储库上的最新 jar。
我的替代方案是回到 saprk 2.4.5 而不是使用 xgboost 0.9 - 这就是我现在正在做的事情。

例如:我将尝试找出导致作业失败的特定模型/数据。

你好,
我找到了具体的模型/数据。
我将其附在此评论中。
xgb1.2_bug.zip

这就是您重新创建错误的方式(请记住,如果您不在这里进行重新分区,它会起作用 - 因此它与每个分区中的数据量或数据类型有关):

from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()

model = PipelineModel.read().load("/tmp/6620294785024229130_model_xg_only")

predictions = model.transform(df)

predictions.persist()
predictions.count()
predictions.show()

你知道什么时候可以解决这个问题吗?
这使我无法使用 spark 3.0 ...

@ranInc还没有。 我们会在修复错误时通知您。 另外,您可以在 Scala 中发布代码吗? 我认为我们从未正式支持将 PySpark 与 XGBoost 一起使用。

      import org.apache.spark.ml.{Pipeline, PipelineModel}
      val df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
      df.count()

      val model = PipelineModel.read.load("/tmp/6620294785024229130_model_xg_only")

      val predictions = model.transform(df)

      predictions.persist()
      predictions.count()
      predictions.show()

另一个指针,
似乎问题是由于发送到要预测的所有特征都是零/缺失。

我想没有人在研究这个吗?
这基本上意味着 xgboost 根本不适用于 spark 3。

是的,很抱歉,我们的手现在很忙。 我们会在某个时候解决这个问题。 我恭敬地请求你的耐心。 谢谢。

@ranInc我今天有一些时间,所以我尝试运行您在此处提供的脚本。 我已经重现了java.lang.NullPointerException错误。

奇怪的是,最新的开发版本( master分支)不会以同样的方式崩溃。 相反,它会产生错误

线程“主”org.apache.spark.SparkException 中的异常:作业因阶段故障而中止:阶段 7.0 中的任务 0 失败 1 次,最近一次失败:阶段 7.0 中丢失任务 0.0(TID 11,d04389c5babb,执行程序驱动程序): ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: 检查失败:learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) :列数与增强器中的特征数不匹配。

我会进一步调查。

我认为错误消息现在很有意义,您的输入比预测模型具有更多功能。

之前jvm包会在xgboost失败后继续执行,导致预测缓冲区为空。 我最近添加了一个检查守卫。

只需确保您的训练数据集中的列数大于或等于您的预测数据集。

你好,
该模型是使用相同数量的特征创建的。
在 spark 中,它使用一个向量列而不是多列。
在任何情况下,向量的大小总是相同的,用于拟合和预测 - 100% 肯定。

这与所有零/缺失特征的行有关。
您可以看到,如果您从数据框中过滤所有零特征的行 - 它工作得很好。

@ranInc您可以发布生成模型的完整 Scala 程序吗? 该错误消息似乎表明您的模型是使用单一特征训练的。

我认为这不会有太大帮助,因为代码非常通用并且有一些支持转换器,
代码本身主要是 pyspark 而不是 scala。

看到特征数量不是问题的最好方法就是过滤掉所有零特征的行并使用模型——这没有问题。
您还可以保留所有行并将数据帧重新分区以使用一个分区,这也可以。

@ranInc我过滤掉了零的行,但仍然面临同样的错误( java.lang.NullPointerException ):

...
df.na.drop(minNonNulls = 1)
...

这不是正确的方法吗?

我认为这不会有太大帮助,因为代码非常通用并且有一些支持转换器

我想看看在训练和预测时使用了多少特征。 错误信息

ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: 检查失败:learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) :列数与增强器中的特征数不匹配。

表明该模型是用一个特征训练的,而预测是用两个特征进行的。

现在,我只能访问您上传的数据框和序列化模型。 我对模型训练的内容和问题缺乏洞察力,这阻碍了我进一步解决问题。 如果你的程序有一些专有信息,是否有可能产生一个干净的例子?

  1. 不,你可以这样做:
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, col}
.......
 val isVectorAllZeros = (col: Vector) => {
        col match {
          case sparse: SparseVector =>
            if(sparse.indices.isEmpty){
              true
            }else{
              false
            }
          case _ => false
        }
      }

      spark.udf.register("isVectorAllZeros", isVectorAllZeros)
      df = df.withColumn("isEmpty",callUDF("isVectorAllZeros",
        col("features_6620294785024229130"))).where("isEmpty == false")

您也可以像这样重新分区数据框:

....
df = df.repartition(1)
....
  1. 我明白了,但是代码不会给你太多,因为它使用了 VectorAssembler,你将无法知道实际使用了多少功能,
    但我 100% 确定它使用了相同数量的功能。

但我 100% 确定它使用了相同数量的功能。

如果 VectorAssembler 导致具有可变数量的特征,您如何确保这一点?

VectorAssembler 总是创建相同数量的特征,它只需要获取列的名称。
代码本身是用来创建数千个模型的,所以它非常通用,基本上得到了一个要使用的名称列表。

我可能能够再次运行模型创建并向您发送用于模型的数据框 - 或您需要的任何其他数据。
不过,这对我来说需要时间,如果您使用我之前展示的内容,您会发现该模型具有 2 个功能,效果很好。

@ranInc让我再问一个问题:说示例数据有一个稀疏列(VectorAssembler)最多有两个特征是否正确?

不。
VectorAssembler 是一个 Trasformer,它抓取多个列并将它们放在一个 Vector 列中。
向量总是用于在 spark 中拟合和预测模型。

这里的示例数据框有一个向量列。
有些行稀疏,有些行密集-都有两个特征。

@ranInc所以所有行都有两个特征,一些值缺失,而另一些则没有。 知道了。 我会尝试你关于过滤空行的建议。

正如您可能已经猜到的那样,我对 Spark 生态系统还很陌生,因此调试工作可能会非常困难。 我们目前需要更多了解 Spark 和 Scala 编程的开发人员。 如果您个人知道有人愿意帮助我们改进 XGBoost 的 JVM 包,请告诉我们。

@ranInc我尝试根据您的建议过滤空行:

程序 A:示例脚本,不过滤空行

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

程序 B:空行过滤示例

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val isVectorAllZeros = (col: Vector) => {
    col match {
      case sparse: SparseVector => (sparse.indices.isEmpty)
      case _ => false
    }
  }
  spark.udf.register("isVectorAllZeros", isVectorAllZeros)

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
                .withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
                .where("isEmpty == false")
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

一些观察

  • 在稳定的 1.2.0 版本中,程序 A出现java.lang.NullPointerException错误。 在 NPE 之前,Spark 执行日志中会显示以下警告:
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1
  • 使用稳定的 1.2.0 版本,程序 B成功完成且没有错误。
  • 使用开发版本(最新的master分支,提交 42d31d9dcb6f7c1cb7d0545e9ab3a305ecad0816),程序 A程序 B都失败并出现以下错误:
[12:44:57] /home/ubuntu/xgblatest/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.                                                                                                        
Stack trace:                                                                                                                                   
  [bt] (0) /tmp/libxgboost4j14081654332866852928.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x79) [0x7f7ef62c4e19]                             [bt] (1) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::ValidateDMatrix(xgboost::DMatrix*, bool) const+0x20b) [0x7f7ef63f5f0b]                                                                                                                                              
  [bt] (2) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::Predict(std::shared_ptr<xgboost::DMatrix>, bool, xgboost::HostDeviceVector<float>*, unsigned int, bool, bool, bool, bool, bool)+0x3c3) [0x7f7ef6400233]                                                             
  [bt] (3) /tmp/libxgboost4j14081654332866852928.so(XGBoosterPredict+0xec) [0x7f7ef62caa3c]                                                      [bt] (4) /tmp/libxgboost4j14081654332866852928.so(Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict+0x47) [0x7f7ef62befd7]             
  [bt] (5) [0x7f80908a8270]

这很奇怪,因为根据@ranInc的说法,该模型是使用具有两个特征的数据进行训练的。

  • 我从源代码构建了版本1.2.0-SNAPSHOT (提交 71197d1dfa27c80add9954b10284848c1f165c40)。 这一次,程序 A程序 B都因功能不匹配错误 ( learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster ) 而失败。
  • 稳定的 1.2.0 版本和1.2.0-SNAPSHOT之间的行为差​​异是出乎意料的,让我很紧张。 特别是来自 1.2.0 的警告消息
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1

在 1.2.0 版本的 C++ 代码库中找不到。 相反,警告出现在release_1.0.0分支中:
https://github.com/dmlc/xgboost/blob/ea6b117a5737f5beb2533fc89b3f3fcd72ecc04e/src/learner.cc#L972 -L982
那么这是否意味着 Maven Central 上的 1.2.0 JAR 文件具有来自 1.0.0 的libxgboost4j.so ? 🤯😱

  • 实际上,来自 Maven Central 的 1.2.0 JAR 文件包含libxgboost4j.so ,实际上是 1.0.0(!!!)。 要找出答案,请从Maven Central下载xgboost4j_2.12-1.2.0.jar并提取libxgboost4j.so文件。 然后运行以下 Python 脚本来验证库文件的版本:
import ctypes

lib = ctypes.cdll.LoadLibrary('./libxgboost4j.so')

major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()

lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
print((major.value, minor.value, patch.value))  # prints (1, 0, 2), indicating version 1.0.2
  • 撇开 1.0.0 的问题不谈,我们清楚地看到经过训练的 XGBoost 模型只识别了一个特征( learner_model_param_.num_feature == 1 )。 也许训练数据有一个 100% 空的特征? @ranInc

您要我获取用于创建模型的数据框吗?
如果我能够抓住它,我想我可以创建一个简单的 scala 代码来创建模型。

@ranInc我怀疑训练数据中的两个特征之一完全由缺失值组成,将learner_model_param_.num_feature设置为 1。所以是的,查看训练数据将非常有帮助。

好吧,我想我会在明天之前把它准备好。

创建 #6426 以跟踪不匹配libxgboost4j.so的问题。 在这里 (#5957) 让我们继续讨论为什么learner_model_param_.num_feature被设置为 1。

看来你错了,训练数据没有缺失值。
在此处的示例代码中,我没有使用重新分区来重现故障,而是仅使用一行(仅具有零特征)进行预测。

features_creation.zip

import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame

val df = spark.read.parquet("/tmp/6620294785024229130_only_features_creation").persist()
df.count()

val regressor = new XGBoostRegressor()
    .setFeaturesCol("features_6620294785024229130")
    .setLabelCol("label_6620294785024229130")
    .setPredictionCol("prediction")
    .setMissing(0.0F)
    .setMaxDepth(3)
    .setNumRound(100)
    .setNumWorkers(1)

val pipeline = new Pipeline().setStages(Array(regressor))
val model = pipeline.fit(df)

val pred = spark.read.parquet("/tmp/6620294785024229130_features").persist()
pred.count()
pred.where("account_code == 4011593987").show()
model.transform(pred.where("account_code == 4011593987")).show()
此页面是否有帮助?
0 / 5 - 0 等级