Xgboost: [jvm-packages] java.lang.NullPointerException: null di ml.dmlc.xgboost4j.java.Booster.predict

Dibuat pada 30 Jul 2020  ·  37Komentar  ·  Sumber: dmlc/xgboost

Pengecualian NPE terjadi saat diprediksi melalui JAVA API.

java.lang.NullPointerException: null
di ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:309)
di ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:375)
di com.tuhu.predict.predict.BaseModelPredict.predict(BaseModelPredict.java:71)
di com.tuhu.predict.predict.XgboostFindPageModelPredict.predict(XgboostFindPageModelPredict.java:53)
di com.tuhu.predict.service.impl.MlpFindPageFeatureServiceImpl.featureProcess(MlpFindPageFeatureServiceImpl.java:65)
di com.tuhu.predict.api.controller.MlpFindPageController.recommendPredict(MlpFindPageController.java:49)
di com.tuhu.predict.api.controller.MlpFindPageController$$FastClassBySpringCGLIB$$f694b9ff.invoke()
di org.springframework.cglib.proxy.MethodProxy.invoke(MethodProxy.java:204)
di org.springframework.aop.framework.CglibAopProxy$CglibMethodInvocation.invokeJoinpoint(CglibAopProxy.java:746)
di org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)
di org.springframework.aop.framework.adapter.MethodBeforeAdviceInterceptor.invoke(MethodBeforeAdviceInterceptor.java:52)
di org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
di org.springframework.aop.aspectj.AspectJAfterAdvice.invoke(AspectJAfterAdvice.java:47)
di org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
di org.springframework.aop.framework.adapter.AfterReturningAdviceInterceptor.invoke(AfterReturningAdviceInterceptor.java:52)
di org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
di org.springframework.aop.aspectj.AspectJAfterThrowingAdvice.invoke(AspectJAfterThrowingAdvice.java:62)
di org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
di org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint.proceed(MethodInvocationProceedingJoinPoint.java:88)
di com.tuhu.springcloud.common.annotation.AbstractControllerLogAspect.doAround(AbstractControllerLogAspect.java:104)
di sun.reflect.NativeMethodAccessorImpl.invoke0(Metode Asli)
di sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
di sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
di java.lang.reflect.Method.invoke(Method.java:498)
di org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethodWithGivenArgs(AbstractAspectJAdvice.java:644)
di org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethod(AbstractAspectJAdvice.java:633)
di org.springframework.aop.aspectj.AspectJAroundAdvice.invoke(AspectJAroundAdvice.java:70)
di org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)
di org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:92)
di org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:185)
di org.springframework.aop.framework.CglibAopProxy$DynamicAdvisedInterceptor.intercept(CglibAopProxy.java:688)
di com.tuhu.predict.api.controller.MlpFindPageController$$EnhancerBySpringCGLIB$$560ed775.recommendPredict()
di sun.reflect.NativeMethodAccessorImpl.invoke0(Metode Asli)
di sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)
di sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)
di java.lang.reflect.Method.invoke(Method.java:498)
di org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:209)
di org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:136)
di org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:102)
di org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:877)
di org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:783)
di org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)
di org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:991)
di org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:925)
di org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:974)
di org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:877)
di javax.servlet.http.HttpServlet.service(HttpServlet.java:661)
di org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:851)
di javax.servlet.http.HttpServlet.service(HttpServlet.java:742)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:52)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di com.tuhu.soter.starter.filter.SoterDefaultFilter.doFilter(SoterDefaultFilter.java:79)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di com.tuhu.boot.logback.filter.LogFilter.doFilter(LogFilter.java:54)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:158)
di org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:126)
di org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.doFilterInternal(WebMvcMetricsFilter.java:111)
di org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter.doFilterInternal(HttpTraceFilter.java:90)
di org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di com.tuhu.boot.common.filter.HeartbeatFilter.doFilter(HeartbeatFilter.java:42)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di com.tuhu.boot.common.filter.MDCFilter.doFilter(MDCFilter.java:47)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)
di org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.springframework.web.filter.HttpPutFormContentFilter.doFilterInternal(HttpPutFormContentFilter.java:109)
di org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.springframework.web.filter.HiddenHttpMethodFilter.doFilterInternal(HiddenHttpMethodFilter.java:93)
di org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:200)
di org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)
di org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java: 193)
di org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)
di org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:198)
di org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:96)
di org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:496)
di org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java: 140)
di org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:81)
di org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:87)
di org.apache.catalina.valves.RemoteIpValve.invoke(RemoteIpValve.java:677)
di org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:342)
di org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:803)
di org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:66)
di org.apache.coyote.AbstractProtocol$ConnectionHandler.process(AbstractProtocol.java:790)
di org.apache.tomcat.util.net.NioEndpoint$SocketProcessor.doRun(NioEndpoint.java:1468)
di org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)
di java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
di java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
di org.apache.tomcat.util.threads.TaskThread$WrappingRunnable.run(TaskThread.java:61)
di java.lang.Thread.run(Thread.java:748)

Komentar yang paling membantu

Baiklah, saya pikir saya akan menyiapkannya besok.

Semua 37 komentar

Karena model dilatih melalui Python Sklearn, ketidakcocokan kemudian terjadi. Untuk menghemat waktu, tim algoritme memindahkan model XGB yang dilatih Sklearn satu lapisan di atas paket Python XgBoost. Saya ingin tahu apakah itu penyebabnya

image

Versi XGBoost mana yang Anda gunakan? Sebelumnya kami memperbaiki bug yang paket jvm tidak melempar pengecualian dengan benar saat prediksi gagal dan melanjutkan dengan buffer prediksi kosong.

Versi XGBoost mana yang Anda gunakan? Sebelumnya kami memperbaiki bug yang paket jvm tidak melempar pengecualian dengan benar saat prediksi gagal dan melanjutkan dengan buffer prediksi kosong.

Versi 1.0 dari platform algoritma perusahaan digunakan, dan versi 0.9.0 dari proyek algoritma digunakan karena masalah kompatibilitas versi. Rekan algoritma menggunakan Python untuk mengonversi file model 1.0 ke 0.9.0. Saya ingin tahu apakah itu disebabkan oleh transformasi ini

Saya sarankan menunggu 1.2 (https://github.com/dmlc/xgboost/issues/5734) dan coba lagi, kami memiliki beberapa perbaikan bug penting dalam rilis ini. Saya juga menyarankan menggunakan versi xgboost yang sama atau lebih baru untuk prediksi. Model biner XGBoost kompatibel ke belakang, direkomendasikan untuk bergerak maju, model berbasis JSON.

Saya mengalami masalah yang sama dengan 1.2.0. Jadi masalahnya masih di sini.

Saya juga mendapat masalah yang sama.
Saya menggunakan xgboost4j untuk membuat model.

apakah ada solusi?

Ini adalah masalah besar bagi saya, itu gagal pekerjaan dalam produksi.

@ranInc Apakah Anda menggunakan XGBoost versi terbaru? Sejauh ini kami belum mengetahui penyebab pasti dari masalah ini. Kami akan mengatasinya dengan upaya terbaik, dan karena tidak ada jaminan kapan masalah tersebut dapat diatasi, saya sarankan Anda menyelidiki alternatif sementara itu.

@ranInc Anda dapat membantu kami dengan memberikan contoh kecil program yang dapat kami (pengembang) jalankan di mesin kami sendiri.

Saya menjalankan 1.2.0, toples terbaru di repositori maven.
Alternatif bagi saya adalah kembali ke saprk 2.4.5 dan daripada menggunakan xgboost 0.9 - dan inilah yang saya lakukan sekarang.

Sebagai contoh: Saya akan mencoba dan menentukan model/data tertentu yang menyebabkan pekerjaan gagal nanti.

Hai,
Saya menemukan model/data tertentu.
Saya lampirkan di komentar ini.
xgb1.2_bug.zip

ini adalah cara Anda membuat ulang bug (perlu diingat bahwa jika Anda tidak melakukan partisi ulang di sini, ini berfungsi - jadi ini ada hubungannya dengan jumlah data atau jenis data di setiap partisi):

from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()

model = PipelineModel.read().load("/tmp/6620294785024229130_model_xg_only")

predictions = model.transform(df)

predictions.persist()
predictions.count()
predictions.show()

Apakah Anda tahu kapan ini bisa diatasi?
Ini mencegah saya menggunakan spark 3.0...

@ranInc Belum. Kami akan memberi tahu Anda saat kami telah memperbaiki bug. Juga, dapatkah Anda memposting kode di Scala? Saya rasa kami tidak pernah secara resmi mendukung penggunaan PySpark dengan XGBoost.

      import org.apache.spark.ml.{Pipeline, PipelineModel}
      val df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
      df.count()

      val model = PipelineModel.read.load("/tmp/6620294785024229130_model_xg_only")

      val predictions = model.transform(df)

      predictions.persist()
      predictions.count()
      predictions.show()

penunjuk lainnya,
sepertinya masalahnya adalah karena semua fitur yang dikirim untuk diprediksi adalah nol/hilang.

Saya kira tidak ada yang mengerjakan ini?
Ini pada dasarnya berarti xgboost tidak bekerja pada spark 3 sama sekali.

Ya maaf tangan kami cukup penuh sekarang. Kami akan membahas masalah ini di beberapa titik. Saya dengan hormat meminta kesabaran Anda. Terima kasih.

@ranInc Saya punya waktu hari ini jadi saya mencoba menjalankan skrip yang Anda berikan di sini. Saya telah mereproduksi kesalahan java.lang.NullPointerException .

Anehnya, versi pengembangan terbaru ( cabang master ) tidak mogok dengan cara yang sama. Sebaliknya, itu menghasilkan kesalahan

Pengecualian di utas "main" org.Apache.spark.SparkException: Pekerjaan dibatalkan karena kegagalan tahap: Tugas 0 di tahap 7.0 gagal 1 kali, kegagalan terbaru: Tugas hilang 0.0 di tahap 7.0 (TID 11, d04389c5babb, driver pelaksana): ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: Pemeriksaan gagal: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Jumlah kolom tidak sesuai dengan jumlah fitur di booster.

Saya akan menyelidiki lebih lanjut.

Saya pikir pesan kesalahan masuk akal sekarang, masukan Anda memiliki lebih banyak fitur daripada model untuk prediksi.

Sebelum paket jvm akan berlanjut setelah kegagalan xgboost, menghasilkan buffer prediksi kosong. Saya menambahkan penjaga cek baru-baru ini.

Pastikan saja jumlah kolom dalam set data pelatihan Anda lebih besar dari atau sama dengan set data prediksi Anda.

Hai,
Model dibuat menggunakan jumlah fitur yang sama.
Dalam percikan itu menggunakan satu kolom vektor dan bukan banyak kolom.
Bagaimanapun ukuran vektor selalu sama, untuk penyesuaian dan prediksi - 100% yakin akan hal itu.

Ini ada hubungannya dengan baris dengan semua fitur nol/hilang.
Anda dapat melihat bahwa jika Anda memfilter dari kerangka data, baris dengan semua fitur nol - ini berfungsi dengan baik.

@ranInc Bisakah Anda memposting program Scala lengkap yang menghasilkan model? Pesan kesalahan tampaknya menunjukkan bahwa model Anda dilatih dengan satu fitur.

Saya tidak berpikir itu akan banyak membantu karena kodenya sangat umum dan memiliki beberapa transformator pendamaian,
Kode itu sendiri sebagian besar pyspark dan bukan scala.

Cara terbaik untuk melihat bahwa jumlah fitur tidak menjadi masalah adalah dengan memfilter baris dengan semua fitur nol dan menggunakan model - ini berfungsi tanpa masalah.
Anda juga dapat menyimpan semua baris dan mempartisi ulang kerangka data untuk menggunakan satu partisi, dan itu juga berfungsi.

@ranInc Saya memfilter baris dengan nol dan masih menghadapi kesalahan yang sama ( java.lang.NullPointerException ):

...
df.na.drop(minNonNulls = 1)
...

Bukankah ini cara yang tepat untuk melakukannya?

Saya tidak berpikir itu akan banyak membantu karena kodenya sangat umum dan memiliki beberapa transformer pendamaian

Saya ingin melihat berapa banyak fitur yang digunakan pada pelatihan dan waktu prediksi. Pesan kesalahan

ml.dmlc.xgboost4j.java.XGBoostError: [00:36:10] /workspace/src/learner.cc:1179: Pemeriksaan gagal: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Jumlah kolom tidak sesuai dengan jumlah fitur di booster.

menunjukkan bahwa model dilatih dengan satu fitur dan prediksi dibuat dengan dua fitur.

Saat ini, saya hanya memiliki akses ke bingkai data dan model serial yang Anda unggah. Saya tidak memiliki wawasan tentang apa yang masuk ke dalam pelatihan model dan apa yang salah, menghalangi saya untuk memecahkan masalah lebih jauh. Jika program Anda memiliki beberapa informasi kepemilikan, apakah mungkin untuk menghasilkan contoh yang bersih?

  1. tidak, Anda dapat melakukan ini:
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, col}
.......
 val isVectorAllZeros = (col: Vector) => {
        col match {
          case sparse: SparseVector =>
            if(sparse.indices.isEmpty){
              true
            }else{
              false
            }
          case _ => false
        }
      }

      spark.udf.register("isVectorAllZeros", isVectorAllZeros)
      df = df.withColumn("isEmpty",callUDF("isVectorAllZeros",
        col("features_6620294785024229130"))).where("isEmpty == false")

anda juga dapat mempartisi ulang kerangka data seperti ini:

....
df = df.repartition(1)
....
  1. Saya mengerti, tetapi kodenya tidak akan memberi Anda banyak, karena menggunakan VectorAssembler, dan Anda tidak akan dapat mengetahui berapa banyak fitur yang sebenarnya digunakan,
    Tapi saya 100% yakin itu menggunakan jumlah fitur yang sama.

Tapi saya 100% yakin itu menggunakan jumlah fitur yang sama.

Bagaimana Anda memastikan ini, jika VectorAssembler menyebabkan jumlah fitur yang bervariasi?

VectorAssembler selalu membuat jumlah fitur yang sama, hanya perlu nama kolom untuk diambil.
Kode itu sendiri digunakan untuk membuat Ribuan model, jadi sangat umum dan pada dasarnya mendapatkan daftar nama untuk digunakan.

Saya mungkin dapat menjalankan pembuatan model lagi dan mengirimkan kerangka data yang digunakan untuk model tersebut - atau data lain yang Anda butuhkan.
Itu akan memakan waktu bagi saya dan jika Anda menggunakan apa yang saya tunjukkan sebelumnya, Anda akan melihat modelnya berfungsi dengan baik dengan 2 fitur.

@ranInc Izinkan saya mengajukan satu pertanyaan lagi: apakah benar untuk mengatakan bahwa contoh data memiliki kolom sparse (VectorAssembler) yang memiliki paling banyak dua fitur ?

Tidak.
VectorAssembler adalah Trasformer yang mengambil beberapa kolom dan menempatkannya dalam satu kolom Vektor.
Vektor selalu digunakan untuk model yang cocok dan memprediksi dalam percikan.

Contoh kerangka data di sini memiliki kolom vektor.
Beberapa baris jarang, yang lain padat - semuanya memiliki dua fitur.

@ranInc Jadi semua baris memiliki dua fitur, beberapa nilai hilang dan lainnya tidak. Mengerti. Saya akan mencoba saran Anda tentang memfilter baris kosong.

Seperti yang mungkin sudah Anda duga, saya cukup baru di ekosistem Spark, jadi upaya debugging mungkin terbukti cukup sulit. Kami saat ini membutuhkan lebih banyak pengembang yang tahu lebih banyak tentang pemrograman Spark dan Scala secara umum. Jika Anda secara pribadi mengenal seseorang yang ingin membantu kami meningkatkan paket JVM XGBoost, beri tahu kami.

@ranInc Saya mencoba memfilter baris kosong sesuai dengan saran Anda:

Program A: Contoh skrip, tanpa memfilter baris kosong

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

Program B: Contoh dengan pemfilteran baris kosong

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val isVectorAllZeros = (col: Vector) => {
    col match {
      case sparse: SparseVector => (sparse.indices.isEmpty)
      case _ => false
    }
  }
  spark.udf.register("isVectorAllZeros", isVectorAllZeros)

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
                .withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
                .where("isEmpty == false")
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

Beberapa pengamatan

  • Dengan rilis 1.2.0 yang stabil, Program A error dengan java.lang.NullPointerException . Tepat sebelum NPE, peringatan berikut ditampilkan di log eksekusi Spark:
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1
  • Dengan rilis 1.2.0 yang stabil, Program B berhasil diselesaikan tanpa kesalahan.
  • Dengan versi pengembangan (cabang master terbaru, komit 42d31d9dcb6f7c1cb7d0545e9ab3a305ecad0816), Program A dan Program B gagal dengan kesalahan berikut:
[12:44:57] /home/ubuntu/xgblatest/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.                                                                                                        
Stack trace:                                                                                                                                   
  [bt] (0) /tmp/libxgboost4j14081654332866852928.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x79) [0x7f7ef62c4e19]                             [bt] (1) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::ValidateDMatrix(xgboost::DMatrix*, bool) const+0x20b) [0x7f7ef63f5f0b]                                                                                                                                              
  [bt] (2) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::Predict(std::shared_ptr<xgboost::DMatrix>, bool, xgboost::HostDeviceVector<float>*, unsigned int, bool, bool, bool, bool, bool)+0x3c3) [0x7f7ef6400233]                                                             
  [bt] (3) /tmp/libxgboost4j14081654332866852928.so(XGBoosterPredict+0xec) [0x7f7ef62caa3c]                                                      [bt] (4) /tmp/libxgboost4j14081654332866852928.so(Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict+0x47) [0x7f7ef62befd7]             
  [bt] (5) [0x7f80908a8270]

yang aneh karena, menurut @ranInc , model dilatih dengan data dengan dua fitur.

  • Saya membuat versi 1.2.0-SNAPSHOT dari sumbernya (komit 71197d1dfa27c80add9954b10284848c1f165c40). Kali ini, Program A dan Program B gagal dengan kesalahan fitur yang tidak cocok ( learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster ).
  • Perbedaan perilaku antara versi stabil 1.2.0 dan 1.2.0-SNAPSHOT tidak terduga dan membuat saya cukup gugup. Secara khusus, pesan peringatan dari 1.2.0
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1

tidak ditemukan dalam basis kode C++ versi 1.2.0. Sebaliknya, peringatan ditemukan di cabang release_1.0.0 :
https://github.com/dmlc/xgboost/blob/ea6b117a5737f5beb2533fc89b3f3fcd72ecc04e/src/learner.cc#L972 -L982
Jadi apakah ini berarti file JAR 1.2.0 di Maven Central memiliki libxgboost4j.so dari 1.0.0 ?? 🤯 😱.

  • Memang, file JAR 1.2.0 dari Maven Central berisi libxgboost4j.so yang sebenarnya 1.0.0 (!!!). Untuk mengetahuinya, unduh xgboost4j_2.12-1.2.0.jar dari Maven Central dan ekstrak file libxgboost4j.so . Kemudian jalankan skrip Python berikut untuk memverifikasi versi file perpustakaan:
import ctypes

lib = ctypes.cdll.LoadLibrary('./libxgboost4j.so')

major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()

lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
print((major.value, minor.value, patch.value))  # prints (1, 0, 2), indicating version 1.0.2
  • Selain masalah 1.0.0, kami dengan jelas melihat bahwa model XGBoost yang terlatih hanya mengenali satu fitur ( learner_model_param_.num_feature == 1 ). Mungkin data pelatihan memiliki fitur yang 100% kosong?? @ranInc

Apakah Anda ingin saya mengambil kerangka data yang digunakan untuk membuat model?
Jika saya bisa mengambilnya, saya pikir saya bisa membuat kode scala sederhana yang membuat model.

@ranInc Kecurigaan saya adalah bahwa salah satu dari dua fitur dalam data pelatihan seluruhnya terdiri dari nilai-nilai yang hilang, menyetel learner_model_param_.num_feature ke 1. Jadi ya, melihat data pelatihan akan sangat membantu.

Baiklah, saya pikir saya akan menyiapkannya besok.

Dibuat #6426 untuk melacak masalah ketidakcocokan libxgboost4j.so . Di sini (#5957) mari kita lanjutkan diskusi tentang mengapa learner_model_param_.num_feature disetel ke 1.

Sepertinya Anda salah, data pelatihan tidak memiliki nilai yang hilang.
pada kode contoh di sini, alih-alih menyampaikan partisi ulang untuk mereproduksi kegagalan, saya malah menggunakan hanya satu baris (yang hanya memiliki nol fitur) untuk prediksi.

features_creation.zip

import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame

val df = spark.read.parquet("/tmp/6620294785024229130_only_features_creation").persist()
df.count()

val regressor = new XGBoostRegressor()
    .setFeaturesCol("features_6620294785024229130")
    .setLabelCol("label_6620294785024229130")
    .setPredictionCol("prediction")
    .setMissing(0.0F)
    .setMaxDepth(3)
    .setNumRound(100)
    .setNumWorkers(1)

val pipeline = new Pipeline().setStages(Array(regressor))
val model = pipeline.fit(df)

val pred = spark.read.parquet("/tmp/6620294785024229130_features").persist()
pred.count()
pred.where("account_code == 4011593987").show()
model.transform(pred.where("account_code == 4011593987")).show()
Apakah halaman ini membantu?
0 / 5 - 0 peringkat