Xgboost: [jvm-paquetes] java.lang.NullPointerException: nulo en ml.dmlc.xgboost4j.java.Booster.predict

Creado en 30 jul. 2020  ·  37Comentarios  ·  Fuente: dmlc/xgboost

Las excepciones de NPE ocurren cuando se predicen a través de la API de JAVA.

java.lang.NullPointerException: nulo
en ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:309)
en ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:375)
en com.tuhu.predict.predict.BaseModelPredict.predict(BaseModelPredict.java:71)
en com.tuhu.predict.predict.XgboostFindPageModelPredict.predict(XgboostFindPageModelPredict.java:53)
en com.tuhu.predict.service.impl.MlpFindPageFeatureServiceImpl.featureProcess(MlpFindPageFeatureServiceImpl.java:65)
en com.tuhu.predict.api.controller.MlpFindPageController.recommendPredict(MlpFindPageController.java:49)
en com.tuhu.predict.api.controller.MlpFindPageController$$FastClassBySpringCGLIB$$f694b9ff.invoke()
en org.springframework.cglib.proxy.MethodProxy.invoke(MethodProxy.java:204)
en org.springframework.aop.framework.CglibAopProxy$CglibMethodInvocation.invokeJoinpoint(CglibAopProxy.java:746)
en org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)
en org.springframework.aop.framework.adapter.MethodBeforeAdviceInterceptor.invoke(MethodBeforeAdviceInterceptor.java:52)
en org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
en org.springframework.aop.aspectj.AspectJAfterAdvice.invoke(AspectJAfterAdvice.java:47)
en org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
en org.springframework.aop.framework.adapter.AfterReturningAdviceInterceptor.invoke(AfterReturningAdviceInterceptor.java:52)
en org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
en org.springframework.aop.aspectj.AspectJAfterThrowingAdvice.invoke(AspectJAfterThrowingAdvice.java:62)
en org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
en org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint.proceed(MethodInvocationProceedingJoinPoint.java:88)
en com.tuhu.springcloud.common.annotation.AbstractControllerLogAspect.doAround(AbstractControllerLogAspect.java:104)
en sun.reflect.NativeMethodAccessorImpl.invoke0(Método nativo)
en sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
en sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
en java.lang.reflect.Method.invoke(Method.java:498)
en org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethodWithGivenArgs(AbstractAspectJAdvice.java:644)
en org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethod(AbstractAspectJAdvice.java:633)
en org.springframework.aop.aspectj.AspectJAroundAdvice.invoke(AspectJAroundAdvice.java:70)
en org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
en org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:92)
en org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:185)
en org.springframework.aop.framework.CglibAopProxy$DynamicAdvisedInterceptor.intercept(CglibAopProxy.java:688)
en com.tuhu.predict.api.controller.MlpFindPageController$$EnhancerBySpringCGLIB$$560ed775.recommendPredict()
en sun.reflect.NativeMethodAccessorImpl.invoke0(Método nativo)
en sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
en sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
en java.lang.reflect.Method.invoke(Method.java:498)
en org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:209)
en org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:136)
en org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:102)
en org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:877)
en org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:783)
en org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
en org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:991)
en org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:925)
en org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:974)
en org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:877)
en javax.servlet.http.HttpServlet.service(HttpServlet.java:661)
en org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:851)
en javax.servlet.http.HttpServlet.service(HttpServlet.java:742)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:52)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en com.tuhu.soter.starter.filter.SoterDefaultFilter.doFilter(SoterDefaultFilter.java:79)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en com.tuhu.boot.logback.filter.LogFilter.doFilter(LogFilter.java:54)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:158)
en org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:126)
en org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.doFilterInternal(WebMvcMetricsFilter.java:111)
en org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter.doFilterInternal(HttpTraceFilter.java:90)
en org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en com.tuhu.boot.common.filter.HeartbeatFilter.doFilter(HeartbeatFilter.java:42)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en com.tuhu.boot.common.filter.MDCFilter.doFilter(MDCFilter.java:47)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)
en org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.springframework.web.filter.HttpPutFormContentFilter.doFilterInternal(HttpPutFormContentFilter.java:109)
en org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.springframework.web.filter.HiddenHttpMethodFilter.doFilterInternal(HiddenHttpMethodFilter.java:93)
en org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:200)
en org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
en org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)
en org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
en org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:198)
en org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:96)
en org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:496)
en org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:140)
en org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:81)
en org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:87)
en org.apache.catalina.valves.RemoteIpValve.invoke(RemoteIpValve.java:677)
en org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:342)
en org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:803)
en org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:66)
en org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:790)
en org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1468)
en org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
en java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
en java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
en org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
en java.lang.Thread.run(Thread.java:748)

Comentario más útil

Muy bien, creo que lo tendré listo para mañana.

Todos 37 comentarios

Debido a que el modelo se entrena a través de Python Sklearn, luego ocurren incompatibilidades. Para ahorrar tiempo, el equipo de algoritmos movió el modelo XGB entrenado por Sklearn una capa sobre el paquete Python XgBoost. Me pregunto si eso es lo que lo causó.

image

¿Qué versión de XGBoost estás usando? Anteriormente, solucionamos un error que indicaba que el paquete jvm no arrojaba una excepción correctamente cuando fallaba la predicción y continuaba con un búfer de predicción vacío.

¿Qué versión de XGBoost estás usando? Anteriormente, solucionamos un error que indicaba que el paquete jvm no arrojaba una excepción correctamente cuando fallaba la predicción y continuaba con un búfer de predicción vacío.

Se usa la versión 1.0 de la plataforma de algoritmos de la compañía y la versión 0.9.0 del proyecto de algoritmo debido a problemas de compatibilidad de versiones. Los colegas del algoritmo usaron Python para convertir el archivo del modelo 1.0 a 0.9.0. Me pregunto si es causado por esta transformación.

Sugeriría esperar a 1.2 (https://github.com/dmlc/xgboost/issues/5734) e intentarlo de nuevo, tenemos algunas correcciones de errores importantes en esta versión. También sugeriría usar la misma versión xgboost o una posterior para la predicción. El modelo binario de XGBoost es compatible con versiones anteriores; en adelante, se recomienda el modelo basado en JSON.

Me encontré con el mismo problema con 1.2.0. Así que el problema sigue aquí.

También tengo el mismo problema.
Usé xgboost4j para crear el modelo.

¿Hay una solución?

Este es un gran problema para mí, falló trabajos en producción.

@ranInc ¿Está utilizando la última versión de XGBoost? Hasta el momento no conocemos la causa exacta de este problema. Lo abordaremos en la medida de lo posible y, dado que no hay garantía de cuándo podría solucionarse el problema, le sugiero que mientras tanto investigue una alternativa.

@ranInc Puede ayudarnos proporcionando un pequeño programa de ejemplo que nosotros (los desarrolladores) podemos ejecutar en nuestra propia máquina.

Estoy ejecutando 1.2.0, el último jar en el repositorio maven.
La alternativa para mí es volver a saprk 2.4.5 y luego usar xgboost 0.9, y esto es lo que estoy haciendo ahora.

Para el ejemplo: intentaré identificar el modelo/datos específicos que causan que el trabajo falle más tarde.

Hola,
Encontré el modelo/datos específicos.
Lo adjunto en este comentario.
xgb1.2_bug.zip

así es como recrea el error (tenga en cuenta que si no hace la partición aquí, funciona, por lo que tiene algo que ver con la cantidad de datos o el tipo de datos en cada partición):

from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()

model = PipelineModel.read().load("/tmp/6620294785024229130_model_xg_only")

predictions = model.transform(df)

predictions.persist()
predictions.count()
predictions.show()

¿Tienes alguna idea de cuándo se puede abordar esto?
Esto me impide usar chispa 3.0...

@ranInc Todavía no. Te avisaremos cuando solucionemos el error. Además, ¿puedes publicar el código en Scala? Creo que nunca apoyamos oficialmente el uso de PySpark con XGBoost.

      import org.apache.spark.ml.{Pipeline, PipelineModel}
      val df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
      df.count()

      val model = PipelineModel.read.load("/tmp/6620294785024229130_model_xg_only")

      val predictions = model.transform(df)

      predictions.persist()
      predictions.count()
      predictions.show()

Otro puntero,
parece que el problema se debe a que todas las características que se envían para predecir son ceros/faltan.

¿Supongo que nadie está trabajando en esto?
Básicamente, esto significa que xgboost no funciona en Spark 3 en absoluto.

Sí, lo siento, nuestras manos están bastante ocupadas en este momento. Nos ocuparemos de este tema en algún momento. Respetuosamente les pido paciencia. Gracias.

@ranInc Tuve algo de tiempo hoy, así que intenté ejecutar el script que proporcionó aquí. He reproducido el error java.lang.NullPointerException .

Extrañamente, la última versión de desarrollo (rama master ) no falla de la misma manera. En cambio, produce error.

Excepción en el subproceso "principal" org.apache.spark.SparkException: trabajo cancelado debido a un error de etapa: la tarea 0 en la etapa 7.0 falló 1 vez, falla más reciente: tarea perdida 0.0 en la etapa 7.0 (TID 11, d04389c5babb, controlador ejecutor): ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: Falló la comprobación: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 frente a 2) : El número de columnas no coincide con el número de funciones en el refuerzo.

Investigaré más a fondo.

Creo que el mensaje de error tiene sentido ahora, su entrada tiene más características que el modelo de predicción.

Antes de que el paquete jvm continúe después de la falla de xgboost, lo que resultará en un búfer de predicción vacío. Agregué un protector de cheques recientemente.

Solo asegúrese de que la cantidad de columnas en su conjunto de datos de entrenamiento sea mayor o igual que su conjunto de datos de predicción.

Hola,
El modelo fue creado usando la misma cantidad de características.
En Spark, utiliza una columna vectorial y no varias columnas.
En cualquier caso, el tamaño del vector es siempre el mismo, para ajuste y predicción, 100% seguro de eso.

Esto tiene algo que ver con las filas con todas las características cero/faltantes.
Puede ver que si filtra desde el marco de datos las filas con todas las características cero, funciona bien.

@ranInc ¿Puede publicar el programa Scala completo que generó el modelo? El mensaje de error parece sugerir que su modelo fue entrenado con una sola función.

No creo que ayude mucho ya que el código es muy genérico y tiene algunos transformadores propiciatorios,
El código en sí es principalmente pyspark y no scala.

La mejor manera de ver que la cantidad de funciones no es un problema es simplemente filtrar las filas con todas las funciones cero y usar el modelo; esto funciona sin problemas.
También puede mantener todas las filas y volver a particionar el marco de datos para usar una partición, y eso también funciona.

@ranInc Filtré filas con cero y sigo enfrentando el mismo error ( java.lang.NullPointerException ):

...
df.na.drop(minNonNulls = 1)
...

¿No es esta la forma correcta de hacerlo?

No creo que ayude mucho ya que el código es muy genérico y tiene algunos transformadores propiciatorios.

Quiero ver cuántas funciones se utilizan en el entrenamiento y en el momento de la predicción. el mensaje de error

ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: Falló la comprobación: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 frente a 2) : El número de columnas no coincide con el número de funciones en el refuerzo.

sugiere que el modelo se entrenó con una sola característica y la predicción se realiza con dos características.

En este momento, solo tengo acceso al marco de datos y al modelo serializado que subiste. Me falta información sobre lo que sucedió en el entrenamiento del modelo y lo que salió mal, lo que me impide seguir solucionando el problema. Si su programa tiene alguna información propietaria, ¿es posible producir un ejemplo limpio?

  1. no, puedes hacer esto:
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, col}
.......
 val isVectorAllZeros = (col: Vector) => {
        col match {
          case sparse: SparseVector =>
            if(sparse.indices.isEmpty){
              true
            }else{
              false
            }
          case _ => false
        }
      }

      spark.udf.register("isVectorAllZeros", isVectorAllZeros)
      df = df.withColumn("isEmpty",callUDF("isVectorAllZeros",
        col("features_6620294785024229130"))).where("isEmpty == false")

también puede simplemente volver a particionar el marco de datos de esta manera:

....
df = df.repartition(1)
....
  1. Lo entiendo, pero el código no le dará mucho, porque usa VectorAssembler, y no podrá saber cuántas funciones se usaron realmente,
    Pero estoy 100% seguro de que usó la misma cantidad de funciones.

Pero estoy 100% seguro de que usó la misma cantidad de funciones.

¿Cómo aseguró esto, si VectorAssembler hace que tenga un número variable de características?

VectorAssembler siempre crea la misma cantidad de características, solo necesita los nombres de las columnas para tomar.
El código en sí se usa para crear Miles de modelos, por lo que es muy genérico y básicamente obtiene una lista de nombres para usar.

Es posible que pueda volver a ejecutar la creación del modelo y enviarle el marco de datos utilizado para el modelo, o cualquier otro dato que necesite.
Sin embargo, eso me llevará tiempo y si usa lo que mostré antes, verá que el modelo funciona bien con 2 características.

@ranInc Permítanme hacer una pregunta más: ¿es correcto decir que los datos de ejemplo tienen una columna dispersa (VectorAssembler) que tiene como máximo dos características ?

No.
VectorAssembler es un Transformador que toma múltiples columnas y las coloca en una columna Vector.
Los vectores siempre se utilizan para el ajuste y la predicción de modelos en chispa.

El marco de datos de ejemplo aquí tiene una columna de vector.
Algunas filas son escasas, otras densas; todas tienen dos características.

@ranInc Entonces, todas las filas tienen dos características, faltan algunos valores y otros no. Entiendo. Intentaré su sugerencia sobre filtrar filas vacías.

Como habrás adivinado, soy bastante nuevo en el ecosistema Spark, por lo que el esfuerzo de depuración puede resultar bastante difícil. Actualmente necesitamos más desarrolladores que sepan más sobre la programación de Spark y Scala en general. Si conoce personalmente a alguien a quien le gustaría ayudarnos a mejorar el paquete JVM de XGBoost, háganoslo saber.

@ranInc Intenté filtrar filas vacías según su sugerencia:

Programa A: script de ejemplo, sin filtrar por filas vacías

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

Programa B: ejemplo con filtrado de filas vacías

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val isVectorAllZeros = (col: Vector) => {
    col match {
      case sparse: SparseVector => (sparse.indices.isEmpty)
      case _ => false
    }
  }
  spark.udf.register("isVectorAllZeros", isVectorAllZeros)

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
                .withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
                .where("isEmpty == false")
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

Algunas observaciones

  • Con la versión estable 1.2.0, el Programa A genera un error con java.lang.NullPointerException . Justo antes de la NPE, se muestra la siguiente advertencia en el registro de ejecución de Spark:
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1
  • Con la versión estable 1.2.0, el programa B se completa con éxito y sin errores.
  • Con la versión de desarrollo (última rama master , confirmación 42d31d9dcb6f7c1cb7d0545e9ab3a305ecad0816), tanto el Programa A como el Programa B fallan con el siguiente error:
[12:44:57] /home/ubuntu/xgblatest/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.                                                                                                        
Stack trace:                                                                                                                                   
  [bt] (0) /tmp/libxgboost4j14081654332866852928.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x79) [0x7f7ef62c4e19]                             [bt] (1) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::ValidateDMatrix(xgboost::DMatrix*, bool) const+0x20b) [0x7f7ef63f5f0b]                                                                                                                                              
  [bt] (2) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::Predict(std::shared_ptr<xgboost::DMatrix>, bool, xgboost::HostDeviceVector<float>*, unsigned int, bool, bool, bool, bool, bool)+0x3c3) [0x7f7ef6400233]                                                             
  [bt] (3) /tmp/libxgboost4j14081654332866852928.so(XGBoosterPredict+0xec) [0x7f7ef62caa3c]                                                      [bt] (4) /tmp/libxgboost4j14081654332866852928.so(Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict+0x47) [0x7f7ef62befd7]             
  [bt] (5) [0x7f80908a8270]

lo cual es extraño porque, según @ranInc , el modelo se entrenó con datos con dos características.

  • Creé la versión 1.2.0-SNAPSHOT desde la fuente (confirmar 71197d1dfa27c80add9954b10284848c1f165c40). Esta vez, tanto el Programa A como el Programa B fallan con el error de falta de coincidencia de características ( learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster ).
  • La diferencia de comportamiento entre la versión estable 1.2.0 y 1.2.0-SNAPSHOT fue inesperada y me puso bastante nervioso. En particular, el mensaje de advertencia de 1.2.0
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1

no se encuentra en la versión 1.2.0 del código base de C++. En cambio, la advertencia se encuentra en la rama release_1.0.0 :
https://github.com/dmlc/xgboost/blob/ea6b117a5737f5beb2533fc89b3f3fcd72ecc04e/src/learner.cc#L972 -L982
Entonces, ¿significa que el archivo JAR 1.2.0 en Maven Central tiene libxgboost4j.so de 1.0.0? 🤯 😱

  • De hecho, el archivo JAR 1.2.0 de Maven Central contiene libxgboost4j.so que en realidad es 1.0.0 (!!!). Para averiguarlo, descargue xgboost4j_2.12-1.2.0.jar de Maven Central y extraiga el archivo libxgboost4j.so . Luego ejecute el siguiente script de Python para verificar la versión del archivo de la biblioteca:
import ctypes

lib = ctypes.cdll.LoadLibrary('./libxgboost4j.so')

major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()

lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
print((major.value, minor.value, patch.value))  # prints (1, 0, 2), indicating version 1.0.2
  • Dejando a un lado el problema de 1.0.0, vemos claramente que el modelo XGBoost entrenado reconoció solo una característica única ( learner_model_param_.num_feature == 1 ). ¿Quizás los datos de entrenamiento tenían una característica que estaba 100% vacía? @ranInc

¿Quieres que tome el marco de datos utilizado para crear el modelo?
Si puedo agarrarlo, creo que puedo crear un código Scala simple que crea el modelo.

@ranInc Mi sospecha es que una de las dos características en los datos de entrenamiento consistía completamente en valores faltantes, estableciendo learner_model_param_.num_feature en 1. Entonces, sí, ver los datos de entrenamiento será muy útil.

Muy bien, creo que lo tendré listo para mañana.

Creado #6426 para realizar un seguimiento del problema de libxgboost4j.so no coincidentes. Aquí (#5957) mantengamos la discusión sobre por qué learner_model_param_.num_feature se establece en 1.

Parece que estás equivocado, los datos de entrenamiento no tienen valores faltantes.
en el código de ejemplo aquí, en lugar de confiar en la partición para reproducir la falla, en su lugar usé solo una fila (que solo tiene cero características) para la predicción.

características_creación.zip

import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame

val df = spark.read.parquet("/tmp/6620294785024229130_only_features_creation").persist()
df.count()

val regressor = new XGBoostRegressor()
    .setFeaturesCol("features_6620294785024229130")
    .setLabelCol("label_6620294785024229130")
    .setPredictionCol("prediction")
    .setMissing(0.0F)
    .setMaxDepth(3)
    .setNumRound(100)
    .setNumWorkers(1)

val pipeline = new Pipeline().setStages(Array(regressor))
val model = pipeline.fit(df)

val pred = spark.read.parquet("/tmp/6620294785024229130_features").persist()
pred.count()
pred.where("account_code == 4011593987").show()
model.transform(pred.where("account_code == 4011593987")).show()
¿Fue útil esta página
0 / 5 - 0 calificaciones