Xgboost: [jvm-packages] java.lang.NullPointerException:ml.dmlc.xgboost4j.java.Booster.predictでnull

作成日 2020年07月30日  ·  37コメント  ·  ソース: dmlc/xgboost

NPE例外は、JAVAAPIを介して予測された場合に発生します。

java.lang.NullPointerException:null
ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:309)で
ml.dmlc.xgboost4j.java.Booster.predict(Booster.java:375)で
com.tuhu.predict.predict.BaseModelPredict.predict(BaseModelPredict.java:71)で
com.tuhu.predict.predict.XgboostFindPageModelPredict.predict(XgboostFindPageModelPredict.java:53)で
com.tuhu.predict.service.impl.MlpFindPageFeatureServiceImpl.featureProcess(MlpFindPageFeatureServiceImpl.java:65)で
com.tuhu.predict.api.controller.MlpFindPageController.recommendPredict(MlpFindPageController.java:49)で
com.tuhu.predict.api.controller.MlpFindPageController $$ FastClassBySpringCGLIB $$ f694b9ff.invoke()。
org.springframework.cglib.proxy.MethodProxy.invoke(MethodProxy.java:204)で
org.springframework.aop.framework.CglibAopProxy $ CglibMethodInvocation.invokeJoinpoint(CglibAopProxy.java:746)で
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:163)で
org.springframework.aop.framework.adapter.MethodBeforeAdviceInterceptor.invoke(MethodBeforeAdviceInterceptor.java:52)で
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)で
org.springframework.aop.aspectj.AspectJAfterAdvice.invoke(AspectJAfterAdvice.java:47)で
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)で
org.springframework.aop.framework.adapter.AfterReturningAdviceInterceptor.invoke(AfterReturningAdviceInterceptor.java:52)で
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)で
org.springframework.aop.aspectj.AspectJAfterThrowingAdvice.invoke(AspectJAfterThrowingAdvice.java:62)で
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)で
org.springframework.aop.aspectj.MethodInvocationProceedingJoinPoint.proceed(MethodInvocationProceedingJoinPoint.java:88)で
com.tuhu.springcloud.common.annotation.AbstractControllerLogAspect.doAround(AbstractControllerLogAspect.java:104)で
sun.reflect.NativeMethodAccessorImpl.invoke0(ネイティブメソッド)で
sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)で
sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)で
java.lang.reflect.Method.invoke(Method.java:498)で
org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethodWithGivenArgs(AbstractAspectJAdvice.java:644)で
org.springframework.aop.aspectj.AbstractAspectJAdvice.invokeAdviceMethod(AbstractAspectJAdvice.java:633)で
org.springframework.aop.aspectj.AspectJAroundAdvice.invoke(AspectJAroundAdvice.java:70)で
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:174)で
org.springframework.aop.interceptor.ExposeInvocationInterceptor.invoke(ExposeInvocationInterceptor.java:92)で
org.springframework.aop.framework.ReflectiveMethodInvocation.proceed(ReflectiveMethodInvocation.java:185)で
org.springframework.aop.framework.CglibAopProxy $ DynamicAdvisedInterceptor.intercept(CglibAopProxy.java:688)で
com.tuhu.predict.api.controller.MlpFindPageController $$ EnhancerBySpringCGLIB $$ 560ed775.recommendPredict()。
sun.reflect.NativeMethodAccessorImpl.invoke0(ネイティブメソッド)で
sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62)で
sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43)で
java.lang.reflect.Method.invoke(Method.java:498)で
org.springframework.web.method.support.InvocableHandlerMethod.doInvoke(InvocableHandlerMethod.java:209)で
org.springframework.web.method.support.InvocableHandlerMethod.invokeForRequest(InvocableHandlerMethod.java:136)で
org.springframework.web.servlet.mvc.method.annotation.ServletInvocableHandlerMethod.invokeAndHandle(ServletInvocableHandlerMethod.java:102)で
org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.invokeHandlerMethod(RequestMappingHandlerAdapter.java:877)で
org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter.handleInternal(RequestMappingHandlerAdapter.java:783)で
org.springframework.web.servlet.mvc.method.AbstractHandlerMethodAdapter.handle(AbstractHandlerMethodAdapter.java:87)で
org.springframework.web.servlet.DispatcherServlet.doDispatch(DispatcherServlet.java:991)で
org.springframework.web.servlet.DispatcherServlet.doService(DispatcherServlet.java:925)で
org.springframework.web.servlet.FrameworkServlet.processRequest(FrameworkServlet.java:974)で
org.springframework.web.servlet.FrameworkServlet.doPost(FrameworkServlet.java:877)で
javax.servlet.http.HttpServlet.service(HttpServlet.java:661)で
org.springframework.web.servlet.FrameworkServlet.service(FrameworkServlet.java:851)で
javax.servlet.http.HttpServlet.service(HttpServlet.java:742)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:231)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.apache.tomcat.websocket.server.WsFilter.doFilter(WsFilter.java:52)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
com.tuhu.soter.starter.filter.SoterDefaultFilter.doFilter(SoterDefaultFilter.java:79)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
com.tuhu.boot.logback.filter.LogFilter.doFilter(LogFilter.java:54)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:158)で
org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.filterAndRecordMetrics(WebMvcMetricsFilter.java:126)で
org.springframework.boot.actuate.metrics.web.servlet.WebMvcMetricsFilter.doFilterInternal(WebMvcMetricsFilter.java:111)で
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.springframework.boot.actuate.web.trace.servlet.HttpTraceFilter.doFilterInternal(HttpTraceFilter.java:90)で
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
com.tuhu.boot.common.filter.HeartbeatFilter.doFilter(HeartbeatFilter.java:42)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
com.tuhu.boot.common.filter.MDCFilter.doFilter(MDCFilter.java:47)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.springframework.web.filter.RequestContextFilter.doFilterInternal(RequestContextFilter.java:99)で
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.springframework.web.filter.HttpPutFormContentFilter.doFilterInternal(HttpPutFormContentFilter.java:109)で
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.springframework.web.filter.HiddenHttpMethodFilter.doFilterInternal(HiddenHttpMethodFilter.java:93)で
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.springframework.web.filter.CharacterEncodingFilter.doFilterInternal(CharacterEncodingFilter.java:200)で
org.springframework.web.filter.OncePerRequestFilter.doFilter(OncePerRequestFilter.java:107)で
org.apache.catalina.core.ApplicationFilterChain.internalDoFilter(ApplicationFilterChain.java:193)で
org.apache.catalina.core.ApplicationFilterChain.doFilter(ApplicationFilterChain.java:166)で
org.apache.catalina.core.StandardWrapperValve.invoke(StandardWrapperValve.java:198)で
org.apache.catalina.core.StandardContextValve.invoke(StandardContextValve.java:96)で
org.apache.catalina.authenticator.AuthenticatorBase.invoke(AuthenticatorBase.java:496)で
org.apache.catalina.core.StandardHostValve.invoke(StandardHostValve.java:140)で
org.apache.catalina.valves.ErrorReportValve.invoke(ErrorReportValve.java:81)で
org.apache.catalina.core.StandardEngineValve.invoke(StandardEngineValve.java:87)で
org.apache.catalina.valves.RemoteIpValve.invoke(RemoteIpValve.java:677)で
org.apache.catalina.connector.CoyoteAdapter.service(CoyoteAdapter.java:342)で
org.apache.coyote.http11.Http11Processor.service(Http11Processor.java:803)で
org.apache.coyote.AbstractProcessorLight.process(AbstractProcessorLight.java:66)で
org.apache.coyote.AbstractProtocol $ ConnectionHandler.process(AbstractProtocol.java:790)で
org.apache.tomcat.util.net.NioEndpoint $ SocketProcessor.doRun(NioEndpoint.java:1468)で
org.apache.tomcat.util.net.SocketProcessorBase.run(SocketProcessorBase.java:49)で
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)で
java.util.concurrent.ThreadPoolExecutor $ Worker.run(ThreadPoolExecutor.java:624)で
org.apache.tomcat.util.threads.TaskThread $ WrappingRunnable.run(TaskThread.java:61)で
java.lang.Thread.run(Thread.java:748)で

最も参考になるコメント

了解しました。明日までに準備できると思います。

全てのコメント37件

モデルはPythonSklearnを介してトレーニングされているため、後で非互換性が発生します。時間を節約するために、アルゴリズムチームはSklearnでトレーニングされたXGBモデルをPythonXgBoostパッケージの1つ上のレイヤーに移動しました。

image

どのバージョンのXGBoostを使用していますか? 以前、予測が失敗したときにjvmパッケージが例外を正しくスローせず、空の予測バッファーで続行するというバグを修正しました。

どのバージョンのXGBoostを使用していますか? 以前、予測が失敗したときにjvmパッケージが例外を正しくスローせず、空の予測バッファーで続行するというバグを修正しました。

同社のアルゴリズムプラットフォームのバージョン1.0が使用されており、バージョンの互換性の問題のため、アルゴリズムプロジェクトのバージョン0.9.0が使用されています。Algorithmの同僚はPythonを使用して1.0モデルファイルを0.9.0に変換しました。 この変容が原因なのかしら

1.2(https://github.com/dmlc/xgboost/issues/5734)を待ってから再試行することをお勧めします。このリリースには、いくつかの重要なバグ修正があります。 また、予測には同じまたはそれ以降のxgboostバージョンを使用することをお勧めします。 XGBoostのバイナリモデルは下位互換性があり、前方に進むと、JSONベースのモデルが推奨されます。

1.2.0でも同じ問題が発生しました。 したがって、問題はまだここにあります。

私も同じ問題を抱えていました。
xgboost4jを使用してモデルを作成しました。

回避策はありますか?

これは私にとって大きな問題であり、本番環境でのジョブに失敗しました。

@ranInc最新バージョンのXGBoostを使用していますか? これまでのところ、この問題の正確な原因はわかりません。 最善を尽くして対処しますが、いつ問題に対処できるかは保証されていないため、当面は別の方法を検討することをお勧めします。

@ranInc私たち(開発者)が自分のマシンで実行できる小さなサンプルプログラムを提供することで、私たちを助けることができます。

Mavenリポジトリの最新のjarである1.2.0を実行しています。
私にとっての代替手段は、saprk 2.4.5に戻り、xgboost0.9を使用することです。これが私が現在行っていることです。

例:後でジョブが失敗する原因となる特定のモデル/データを特定しようとします。

やあ、
特定のモデル/データを見つけました。
このコメントに添付します。
xgb1.2_bug.zip

これがバグを再現する方法です(ここで再パーティション化を行わないと、機能することに注意してください。したがって、各パーティションのデータの量またはデータのタイプと関係があります)。

from pyspark.ml.pipeline import PipelineModel
from pyspark.sql import SparkSession
from pyspark.sql.dataframe import DataFrame

df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
df.count()

model = PipelineModel.read().load("/tmp/6620294785024229130_model_xg_only")

predictions = model.transform(df)

predictions.persist()
predictions.count()
predictions.show()

これにいつ対処できるか、何か考えがありますか?
これにより、spark3.0を使用できなくなります...

@ranIncまだです。 バグの修正が完了したらお知らせします。 また、Scalaでコードを投稿できますか? XGBoostでのPySparkの使用を公式にサポートしたことはないと思います。

      import org.apache.spark.ml.{Pipeline, PipelineModel}
      val df = spark.read.parquet("/tmp/6620294785024229130_features").repartition(200).persist()
      df.count()

      val model = PipelineModel.read.load("/tmp/6620294785024229130_model_xg_only")

      val predictions = model.transform(df)

      predictions.persist()
      predictions.count()
      predictions.show()

別のポインタ、
問題は、予測されるように送信されているすべての機能がゼロ/欠落していることが原因のようです。

誰もこれに取り組んでいないと思いますか?
これは基本的に、xgboostがspark3ではまったく機能しないことを意味します。

ええ、申し訳ありませんが、私たちの手は今かなりいっぱいです。 この問題はいつか回避します。 何卒ご理解とご協力を賜りますようお願い申し上げます。 ありがとう。

@ranInc今日は時間があったので、ここで提供したスクリプトを実行してみました。 java.lang.NullPointerExceptionエラーを再現しました。

不思議なことに、最新の開発バージョン( masterブランチ)は同じようにクラッシュしません。 代わりに、エラーが発生します

スレッド"main"org.apache.spark.SparkExceptionの例外:ステージの失敗によりジョブが中止されました:ステージ7.0のタスク0が1回失敗し、最新の失敗:ステージ7.0のタスク0.0が失われました(TID 11、d04389c5babb、エグゼキュータードライバー): ml.dmlc.xgboost4j.java.XGBoostError:[00:36:10] /workspace/src/learner.cc:1179:チェックに失敗しました:learner_model_param_.num_feature> = p_fmat-> Info()。num_col_(1対2) :列の数がブースターの機能の数と一致しません。

さらに調査します。

エラーメッセージは今では理にかなっていると思います。入力には予測用のモデルよりも多くの機能があります。

xgboostが失敗した後、jvmパッケージが続行する前に、予測バッファーが空になります。 最近チェックガードを追加しました。

トレーニングデータセットの列数が予測データセット以上であることを確認してください。

やあ、
モデルは、同じ量の機能を使用して作成されました。
Sparkでは、複数の列ではなく1つのベクトル列を使用します。
いずれにせよ、フィッティングと予測のために、ベクトルのサイズは常に同じです-100%確実です。

これは、すべてのゼロ/欠落機能を備えた行と関係があります。
データフレームからすべてゼロの特徴を持つ行をフィルタリングすると、問題なく機能することがわかります。

@ranIncモデルを生成した完全なScalaプログラムを投稿できますか? エラーメッセージは、モデルが単一の機能でトレーニングされたことを示唆しているようです。

コードは非常に一般的であり、いくつかの適切なトランスフォーマーが含まれているため、あまり役に立たないと思います。
コード自体はほとんどがpysparkであり、scalaではありません。

特徴の数が問題ではないことを確認する最良の方法は、すべてゼロの特徴を持つ行をフィルターで除外し、モデルを使用することです。これは問題なく機能します。
また、すべての行を保持し、データフレームを再パーティション化して1つのパーティションを使用することもできます。これも機能します。

@ranIncゼロの行を除外しても、同じエラー( java.lang.NullPointerException )に直面しています。

...
df.na.drop(minNonNulls = 1)
...

これは正しい方法ではありませんか?

コードは非常に一般的であり、いくつかの適切なトランスフォーマーが含まれているため、あまり役に立たないと思います

トレーニング時と予測時に使用されている機能の数を確認したいと思います。 エラーメッセージ

ml.dmlc.xgboost4j.java.XGBoostError:[00:36:10] /workspace/src/learner.cc:1179:チェックに失敗しました:learner_model_param_.num_feature> = p_fmat-> Info()。num_col_(1対2) :列の数がブースターの機能の数と一致しません。

モデルが単一の機能でトレーニングされ、予測が2つの機能で行われていることを示しています。

現在、アップロードしたデータフレームとシリアル化されたモデルにしかアクセスできません。 モデルトレーニングに何が入ったのか、何がうまくいかなかったのかについての洞察が不足しているため、問題のトラブルシューティングをこれ以上行うことができません。 プログラムに独自の情報がある場合、クリーンな例を作成することは可能ですか?

  1. いいえ、これを行うことができます:
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{callUDF, col}
.......
 val isVectorAllZeros = (col: Vector) => {
        col match {
          case sparse: SparseVector =>
            if(sparse.indices.isEmpty){
              true
            }else{
              false
            }
          case _ => false
        }
      }

      spark.udf.register("isVectorAllZeros", isVectorAllZeros)
      df = df.withColumn("isEmpty",callUDF("isVectorAllZeros",
        col("features_6620294785024229130"))).where("isEmpty == false")

次のようにデータフレームを再パーティション化することもできます。

....
df = df.repartition(1)
....
  1. わかりましたが、VectorAssemblerを使用しているため、コードではあまり効果がなく、実際に使用された機能の数を知ることはできません。
    しかし、私はそれが同じ量の機能を使用したと100%確信しています。

しかし、私はそれが同じ量の機能を使用したと100%確信しています。

VectorAssemblerによってさまざまな数の機能が使用される場合、これをどのように確認しましたか?

VectorAssemblerは常に同じ量の機能を作成し、取得する列の名前が必要なだけです。
コード自体は数千のモデルを作成するために使用されるため、非常に一般的であり、基本的に使用する名前のリストを取得します。

モデルの作成を再度実行して、モデルに使用されているデータフレームまたはその他の必要なデータを送信できる場合があります。
それには時間がかかりますが、前に示したものを使用すると、モデルが2つの機能で問題なく動作することがわかります。

@ranIncもう1つ質問させてください。サンプルデータには、最大2つの機能を持つスパース列(VectorAssembler)があると言うのは正しいですか?

いいえ。
VectorAssemblerは、複数の列を取得して1つのVector列に配置するTrasformerです。
ベクトルは、スパークに適合して予測するモデルに常に使用されます。

ここでのサンプルデータフレームには、ベクトル列があります。
一部の行はまばらで、他の行は密です-すべてに2つの機能があります。

@ranIncしたがって、すべての行には2つの機能があり、一部の値が欠落しており、他の値は欠落しています。 とった。 空の行のフィルタリングについての提案を試してみます。

ご想像のとおり、私はSparkエコシステムにまったく慣れていないため、デバッグ作業は非常に難しい場合があります。 現在、SparkおよびScalaプログラミング全般について詳しく知っている開発者を増やす必要があります。 XGBoostのJVMパッケージの改善を手伝ってくれる人を個人的に知っている場合は、お知らせください。

@ranInc私はあなたの提案に従って空の行をフィルタリングしようとしました:

プログラムA:空の行をフィルタリングしないサンプルスクリプト

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

プログラムB:空の行フィルタリングの例

import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.linalg.{SparseVector, Vector}
import org.apache.spark.sql.functions.{callUDF, col}

object Main extends App {
  val spark = SparkSession
      .builder()
      .appName("XGBoost4J-Spark Pipeline Example")
      .getOrCreate()

  val isVectorAllZeros = (col: Vector) => {
    col match {
      case sparse: SparseVector => (sparse.indices.isEmpty)
      case _ => false
    }
  }
  spark.udf.register("isVectorAllZeros", isVectorAllZeros)

  val df = spark.read.parquet("/home/ubuntu/data/6620294785024229130_features").repartition(200).persist()
                .withColumn("isEmpty", callUDF("isVectorAllZeros", col("features_6620294785024229130")))
                .where("isEmpty == false")
  df.show()

  val model = PipelineModel.read.load("/home/ubuntu/data/6620294785024229130_model_xg_only")

  val predictions = model.transform(df)

  predictions.persist()
  predictions.count()
  predictions.show()
}

いくつかの観察

  • 安定した1.2.0リリースでは、プログラムAjava.lang.NullPointerExceptionでエラーになります。 NPEの直前に、次の警告がSpark実行ログに表示されます。
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1
  • 安定した1.2.0リリースでは、プログラムBはエラーなしで正常に完了します。
  • 開発バージョン(最新のmasterブランチ、コミット42d31d9dcb6f7c1cb7d0545e9ab3a305ecad0816)では、プログラムAプログラムBの両方が次のエラーで失敗します。
[12:44:57] /home/ubuntu/xgblatest/src/learner.cc:1179: Check failed: learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster.                                                                                                        
Stack trace:                                                                                                                                   
  [bt] (0) /tmp/libxgboost4j14081654332866852928.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x79) [0x7f7ef62c4e19]                             [bt] (1) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::ValidateDMatrix(xgboost::DMatrix*, bool) const+0x20b) [0x7f7ef63f5f0b]                                                                                                                                              
  [bt] (2) /tmp/libxgboost4j14081654332866852928.so(xgboost::LearnerImpl::Predict(std::shared_ptr<xgboost::DMatrix>, bool, xgboost::HostDeviceVector<float>*, unsigned int, bool, bool, bool, bool, bool)+0x3c3) [0x7f7ef6400233]                                                             
  [bt] (3) /tmp/libxgboost4j14081654332866852928.so(XGBoosterPredict+0xec) [0x7f7ef62caa3c]                                                      [bt] (4) /tmp/libxgboost4j14081654332866852928.so(Java_ml_dmlc_xgboost4j_java_XGBoostJNI_XGBoosterPredict+0x47) [0x7f7ef62befd7]             
  [bt] (5) [0x7f80908a8270]

@ranIncによると、モデルは2つの機能を備えたデータでトレーニングされているため、これは奇妙なことです。

  • ソースからバージョン1.2.0-SNAPSHOTをビルドしました(コミット71197d1dfa27c80add9954b10284848c1f165c40)。 今回は、プログラムAプログラムBの両方が機能の不一致エラー( learner_model_param_.num_feature >= p_fmat->Info().num_col_ (1 vs. 2) : Number of columns does not match number of features in booster )で失敗します。
  • 安定した1.2.0バージョンと1.2.0-SNAPSHOTの動作の違いは予想外であり、私は非常に緊張しました。 特に、1.2.0からの警告メッセージ
WARNING: /xgboost/src/learner.cc:979: Number of columns does not match number of features in booster. Columns: 0 Features: 1

C++コードベースの1.2.0バージョンにはありません。 代わりに、警告はrelease_1.0.0ブランチにあります。
https://github.com/dmlc/xgboost/blob/ea6b117a5737f5beb2533fc89b3f3fcd72ecc04e/src/learner.cc#L972 -L982
つまり、MavenCentralの1.2.0JARファイルに1.0.0からlibxgboost4j.soがあるということですか? 🤯😱

  • 実際、MavenCentralの1.2.0JARファイルにはlibxgboost4j.soが含まれていますが、これは実際には1.0.0(!!!)です。 調べるには、 Maven Centralからxgboost4j_2.12-1.2.0.jarをダウンロードし、 libxgboost4j.soファイルを抽出します。 次に、次のPythonスクリプトを実行して、ライブラリファイルのバージョンを確認します。
import ctypes

lib = ctypes.cdll.LoadLibrary('./libxgboost4j.so')

major = ctypes.c_int()
minor = ctypes.c_int()
patch = ctypes.c_int()

lib.XGBoostVersion(ctypes.byref(major), ctypes.byref(minor), ctypes.byref(patch))
print((major.value, minor.value, patch.value))  # prints (1, 0, 2), indicating version 1.0.2
  • 1.0.0の問題はさておき、トレーニングされたXGBoostモデルが認識した機能は1つだけであることがはっきりとわかります( learner_model_param_.num_feature == 1 )。 たぶん、トレーニングデータには100%空の機能がありましたか? @ranInc

モデルの作成に使用したデータフレームを取得しますか?
それをつかむことができれば、モデルを作成する簡単なscalaコードを作成できると思います。

@ranInc私の疑いは、トレーニングデータの2つの機能の1つが完全に欠落した値で構成されており、 learner_model_param_.num_featureを1に設定していることです。したがって、トレーニングデータを確認することは非常に役立ちます。

了解しました。明日までに準備できると思います。

libxgboost4j.soの不一致の問題を追跡するために#6426を作成しました。 ここ(#5957)では、 learner_model_param_.num_featureが1に設定されている理由について説明し続けましょう。

あなたが間違っているようです、トレーニングデータには欠測値がありません。
ここのサンプルコードでは、失敗を再現するために再パーティションを中継する代わりに、予測に1つの行(機能がゼロのみ)のみを使用しました。

features_creation.zip

import ml.dmlc.xgboost4j.scala.spark.XGBoostRegressor
import org.apache.spark.ml.Pipeline
import org.apache.spark.sql.SparkSession
import org.apache.spark.ml.PipelineModel
import org.apache.spark.sql.DataFrame

val df = spark.read.parquet("/tmp/6620294785024229130_only_features_creation").persist()
df.count()

val regressor = new XGBoostRegressor()
    .setFeaturesCol("features_6620294785024229130")
    .setLabelCol("label_6620294785024229130")
    .setPredictionCol("prediction")
    .setMissing(0.0F)
    .setMaxDepth(3)
    .setNumRound(100)
    .setNumWorkers(1)

val pipeline = new Pipeline().setStages(Array(regressor))
val model = pipeline.fit(df)

val pred = spark.read.parquet("/tmp/6620294785024229130_features").persist()
pred.count()
pred.where("account_code == 4011593987").show()
model.transform(pred.where("account_code == 4011593987")).show()
このページは役に立ちましたか?
0 / 5 - 0 評価