Scikit-learn: Fehler bei der Netzsuche in der Pipeline mit Keine für Transformatorschritt

Erstellt am 12. Nov. 2020  ·  3Kommentare  ·  Quelle: scikit-learn/scikit-learn

Beschreibe den Fehler

Beim Durchführen einer Rastersuche in einer Pipeline mit None für einen Transformer-Schritt wird ein AttributeError ausgelöst. Dieses Snippet unten lief zuvor erfolgreich mit scikit-learn==0.23.2 , funktioniert aber nicht mehr mit 0.24.dev0 .

Schritte/Code zum Reproduzieren

from sklearn.datasets import load_iris
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC

iris = load_iris()
X, y = iris.data, iris.target

pipe = Pipeline([("setup", None), ("svc", SVC(kernel="linear", random_state=0))])

param_grid = [
    {"svc__C": [0.1, 0.1]},
    {"setup": [StandardScaler()]},
]

gs = GridSearchCV(pipe, param_grid=param_grid, return_train_score=True, cv=3)
gs.fit(X, y)

erwartete Ergebnisse

Beispiel: Es wird kein Fehler ausgegeben. Bitte fügen Sie die erwarteten Ergebnisse ein oder beschreiben Sie sie.

Der GridSearchCV.fit Anruf kann erfolgreich abgeschlossen werden

Tatsächliche Ergebnisse

Bitte fügen Sie die tatsächliche Ausgabe oder das Traceback ein oder beschreiben Sie sie spezifisch.

Der folgende Fehler wird ausgegeben (den vollständigen Traceback habe ich weiter unten eingefügt):

  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/base.py", line 863, in _is_pairwise
    pairwise_tag = estimator._get_tags().get('pairwise', False)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/base.py", line 348, in _get_tags
    more_tags = base_class._more_tags(self)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/pipeline.py", line 626, in _more_tags
    estimator_tags = self.steps[0][1]._get_tags()
AttributeError: 'NoneType' object has no attribute '_get_tags'

Es scheint, dass die _is_pairwise Prüfung nicht wie erwartet funktioniert, wenn sie auf eine Pipeline mit None für einen Stufentransformator angewendet wird.


Vollständige Rückverfolgung:

Traceback (most recent call last):
  File "test-pipeline.py", line 18, in <module>
    gs.fit(X, y)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/utils/validation.py", line 60, in inner_f
    return f(*args, **kwargs)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/model_selection/_search.py", line 841, in fit
    self._run_search(evaluate_candidates)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/model_selection/_search.py", line 1288, in _run_search
    evaluate_candidates(ParameterGrid(self.param_grid))
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/model_selection/_search.py", line 795, in evaluate_candidates
    out = parallel(delayed(_fit_and_score)(clone(base_estimator),
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/joblib/parallel.py", line 1048, in __call__
    if self.dispatch_one_batch(iterator):
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/joblib/parallel.py", line 866, in dispatch_one_batch
    self._dispatch(tasks)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/joblib/parallel.py", line 784, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 208, in apply_async
    result = ImmediateResult(func)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/joblib/_parallel_backends.py", line 572, in __init__
    self.results = batch()
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/joblib/parallel.py", line 262, in __call__
    return [func(*args, **kwargs)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/joblib/parallel.py", line 262, in <listcomp>
    return [func(*args, **kwargs)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/utils/fixes.py", line 222, in __call__
    return self.function(*args, **kwargs)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/model_selection/_validation.py", line 585, in _fit_and_score
    X_train, y_train = _safe_split(estimator, X, y, train)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/utils/metaestimators.py", line 198, in _safe_split
    if _is_pairwise(estimator):
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/base.py", line 863, in _is_pairwise
    pairwise_tag = estimator._get_tags().get('pairwise', False)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/base.py", line 348, in _get_tags
    more_tags = base_class._more_tags(self)
  File "/Users/james/miniforge3/envs/dask-ml/lib/python3.8/site-packages/sklearn/pipeline.py", line 626, in _more_tags
    estimator_tags = self.steps[0][1]._get_tags()
AttributeError: 'NoneType' object has no attribute '_get_tags'

Versionen

System:
    python: 3.8.6 | packaged by conda-forge | (default, Oct  7 2020, 18:42:56)  [Clang 10.0.1 ]
executable: /Users/james/miniforge3/envs/dask-ml/bin/python3.8
   machine: macOS-10.15.5-x86_64-i386-64bit

Python dependencies:
          pip: 20.2.4
   setuptools: 49.6.0.post20201009
      sklearn: 0.24.dev0
        numpy: 1.19.4
        scipy: 1.5.3
       Cython: None
       pandas: 1.1.4
   matplotlib: None
       joblib: 0.17.0
threadpoolctl: 2.1.0

Built with OpenMP: True
Blocker Bug

Hilfreichster Kommentar

Ich werde es als Blocker markieren, da der Fehler nicht nur bei der Verwendung von None auftritt, sondern auch bei der Verwendung eines Schritts, der kein _get_tags Attribut hat (wahrscheinlich, weil es nicht von BaseEstimator )

Alle 3 Kommentare

Danke für den Bericht @jrbourbeau , den wir reproduzieren können. Wir untersuchen die beste Lösung in den verschiedenen oben verlinkten Problemen, wenn Sie interessiert sind

Ich werde es als Blocker markieren, da der Fehler nicht nur bei der Verwendung von None auftritt, sondern auch bei der Verwendung eines Schritts, der kein _get_tags Attribut hat (wahrscheinlich, weil es nicht von BaseEstimator )

Behoben durch #18797. Danke für den zeitnahen Fehlerbericht @jrbourbeau .

War diese Seite hilfreich?
0 / 5 - 0 Bewertungen