Scikit-learn: Suggestion: Remove prediction from plot_confusion_matrix and just pass predicted labels

Created on 13 Dec 2019  ·  61 Comments  ·  Source: scikit-learn/scikit-learn

The signature of plot_confusion_matrix is currently:

sklearn.metrics.plot_confusion_matrix(estimator, X, y_true, labels=None, sample_weight=None, normalize=None, display_labels=None, include_values=True, xticks_rotation='horizontal', values_format=None, cmap='viridis', ax=None)

The function takes an estimator and raw data and cannot be used with already predicted labels. This has some downsides:

  • If a confusion matrix should be plotted but the predictions are also needed elsewhere (e.g. for calculating accuracy_score), the prediction has to be performed several times. That takes longer and can give different values if the estimator is randomized.
  • If no estimator is available (e.g. predictions loaded from a file) the plot cannot be used at all.

Suggestion: allow passing predicted labels y_pred to plot_confusion_matrix that will be used instead of estimator and X. In my opinion the cleanest solution would be to remove the prediction step from the function and use a signature similar to that of accuracy_score, e.g. (y_true, y_pred, labels=None, sample_weight=None, ...). However, in order to maintain backward compatibility, y_pred can be added as an optional keyword argument.
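To make the issue concrete, here is a minimal sketch using the API as it exists at the time of this issue; the final commented-out call shows the proposed, not-yet-existing y_pred keyword:

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, plot_confusion_matrix
from sklearn.model_selection import train_test_split

X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = LogisticRegression().fit(X_train, y_train)

# Predict once and reuse the result, e.g. for accuracy.
y_pred = clf.predict(X_test)
print(accuracy_score(y_test, y_pred))

# Current API: predicts again internally.
plot_confusion_matrix(clf, X_test, y_test)

# Proposed (hypothetical): reuse the existing predictions.
# plot_confusion_matrix(y_true=y_test, y_pred=y_pred)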

model_selection

All 61 comments

We should definitely stay backward compatible, but adding a y_pred keyword arg sounds reasonable to me. We should raise an error if y_pred is passed but X or estimator are also passed.

Would you want to submit a PR @jhennrich ?

I submitted a PR, but I think there is currently a problem with the CI so it has not passed yet.

I agree that we should support plot_XXX(y_true, y_pred) to avoid calculating the prediction multiple times.
We also have similar issues in plot_roc_curve and plot_precision_recall_curve.
Adding y_pred seems acceptable, but honestly I don't think it's a good solution.
For those functions which accept **kwargs (e.g., plot_precision_recall_curve), it seems impossible to keep backward compatibility?

Why is it impossible to keep the backward compatibility? It seems to me that the proposal in #15883 is OK

Why is it impossible to keep the backward compatibility? It seems to me that the proposal in #15883 is OK

because we do not support **kwargs in plot_confusion_matrix. @NicolasHug

Why is kwargs a problem?

Hmm, so there's another annoying thing, we support **kwargs in plot_roc_curve and plot_precision_recall_curve (and plot_partial_dependence), but we do not support it in plot_confusion_matrix

Why is kwargs a problem?

if we add the new parameter before **kwargs, we can keep backward compatibility, right?

The changes in my PR are backwards compatible and **kwargs can still be added. But I agree with @qinhanmin2014, a much much cleaner solution would be to throw out estimator and X and use positional arguments (y_true, y_pred, ...) that are consistent with most of the other sklearn stuff.

if we add the new parameter before **kwargs, we can keep backward compatibility, right?

yes

a much much cleaner solution....

Unfortunately that would require a deprecation cycle (unless we make it very fast in the bugfix release but I doubt it...)

@thomasjpfan , any reason to pass the estimator as input instead of the predictions?

Thanks, let's add y_pred first, **kwargs is another issue.

Unfortunately that would require a deprecation cycle (unless we make it very fast in the bugfix release but I doubt it...)

This seems impossible, sigh

@thomasjpfan , any reason to pass the estimator as input instead of the predictions?

I agree that we need to reconsider our API design. also try to ping @amueller

If a user wants to provide their own plotting part and provide their own confusion matrix:

from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix

cm = confusion_matrix(...)
display_labels = [...]

disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=display_labels)
disp.plot(...)

This can similarly be done for the other metric plotting functions.

The plot_confusion_matrix is kind of designed like the scorers which are able to handle the output of estimators nicely. In other words, it is a convenience wrapper for interacting with ConfusionMatrixDisplay and the estimator.

By accepting the estimator first, there is a uniform interface for the plotting functions. For example, plot_partial_dependence does all the computation needed for creating the partial dependence plots and passes it to PartialDependenceDisplay. A user can still create the PartialDependenceDisplay themselves, but in that case it would be more involved.

Although, I am open to having a "fast path", allowing for y_pred to be passed into the metrics related plotting functions, which will be passed directly to confusion_matrix and let it deal with validation.

The computation of the predictions needed to build PDPs is quite complex. Also, these predictions are typically unusable in e.g. a scorer or a metric. They're only useful for plotting the PDP. So it makes sense in this case to only accept the estimator in plot_partial_dependence.

OTOH for confusion matrix, the predictions are really just est.predict(X).

I don't think we want a uniform interface here. These are 2 very different input use-cases

EDIT: In addition, the tree-based PDPs don't even need predictions at all

There are other things we will run into without the estimator. For example, if plot_precision_recall_curve were to accept y_pred, it would need pos_label because it cannot be inferred anymore. In this case, I would prefer to use PrecisionRecallDisplay directly and have the user calculate the parameters needed to reconstruct the plot.

This comes down to what kind of question we are answering with this API. The current interface revolves around evaluating an estimator, thus using the estimator as an argument. It is motivated by answering "how does this trained model behave with this input data?"

If we accept y_pred, y_true, now the question becomes "how does this metric behave with this data?" This data may or may not be generated by a model.

It's true that in this specific case, @jhennrich you could directly be using the ConfusionMatrixDisplay.

One drawback is that you need to specify display_labels since it has no default.

@thomasjpfan do you think we could in general provide sensible defaults for the Display objects, thus still making the direct use of the Display objects practical?

For some parameters, like display_labels, there is a reasonable default. The other Display object parameters can have reasonable defaults as well. Some parameters must be provided though. For example, confusion_matrix must be provided for ConfusionMatrixDisplay or precision and recall for PrecisionRecallDisplay.

One classic pattern for this kind of thing is defining:

ConfusionMatrixDisplay.from_estimator(...)
ConfusionMatrixDisplay.from_predictions(...)

but this is not very idiomatic to scikit-learn.
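For concreteness, such classmethods might be used like this (purely a hypothetical sketch: neither method exists at this point and the argument lists are assumed; clf, X_test, y_test and y_pred stand for a fitted classifier, held-out data and precomputed predictions):

from sklearn.metrics import ConfusionMatrixDisplay

# Convenience path: the Display would compute the predictions itself.
disp = ConfusionMatrixDisplay.from_estimator(clf, X_test, y_test)

# Precomputed path: reuse predictions computed elsewhere.
disp = ConfusionMatrixDisplay.from_predictions(y_test, y_pred)

disp.plot()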

I'm starting to get confused. The goal of the current API is to avoid recomputing when users want to plot multiple times, but if we accept y_true and y_pred, users still don't need to recompute, right? (I know that things are different in PDP)

@jnothman That API is pretty nice looking!

@qinhanmin2014 Passing an estimator, X, y or y_true, y_pred works in satisfying the "do not calculate multiple times" API. In both cases, the confusion matrix is computed and stored into the Display object.

The difference between them is where the calculation of the confusion matrix starts. One can think of passing y_pred as the "precomputed" output of the estimator.

So I think y_true, y_pred is better than estimator, X, y (not in PDP of course), because sometimes (often?) users not only want to plot the predictions, they also want to analyze the predictions. With the current API, they'll need to calculate the predictions multiple times.

For metrics, I can see the preference toward using y_true, y_pred over estimator, X, y. Imagine if the plotting for metrics supported only y_true, y_pred:

est = # fit estimator

plot_partial_dependence(est, X, ...)

# if plot_confusion_matrix accepts `y_true, y_pred`
y_pred = est.predict(X)
plot_confusion_matrix(y_true, y_pred, ...)

# if plot_roc_curve supports `y_true, y_score`
y_score = est.predict_proba(X)[:, 1]
plot_roc_curve(y_true, y_score, ...)
plot_precision_recall_curve(y_true, y_score, ...)

Currently the API looks like:

est = # fit estimator
plot_partial_dependence(est, X, ...)
plot_confusion_matrix(est, X, y, ...)
plot_roc_curve(est, X, y, ...)

# this will call `predict_proba` again
plot_precision_recall_curve(est, X, y, ...)

I would prefer to have an API that supports both options (somehow).

For metrics, I can see the preference toward using y_true, y_pred over estimator, X, y. Imagine if the plotting for metrics supported only y_true, y_pred

Yes, this is what I mean.

I would prefer to have an API that supports both options (somehow).

I think this is a practical solution. An annoying thing is that we can only add y_pred at the end (i.e., plot_confusion_matrix(estimator, X, y_true, ..., y_pred))

Yup it will be at the end and the API would look like this:

plot_confusion_matrix(y_true=y_true, y_pred=y_pred, ...)

which I think I am okay with. This is essentially the PR https://github.com/scikit-learn/scikit-learn/pull/15883

Yup it will be at the end and the API would look like this plot_confusion_matrix(y_true=y_true, y_pred=y_pred, ...)

I guess you mean that we should add y_true and remove est & X, right? I guess it's impossible? (because we can only add y_pred at the end)

Do we want to solve this in 0.22.1? @NicolasHug @thomasjpfan I think it's worthwhile to put this in 0.22.1, but at the same time, it seems that this is a new feature.

No, don't put it in 0.22.1. it is a clear violation of semver

@qinhanmin2014 Adding y_pred at the end or removing est, X seems like a new feature that belongs in the next release.

I guess you mean that we should add y_true and remove est & X, right? I guess it's impossible?

In the end I would prefer to have both interfaces, because they have slightly different use cases.

  1. est, X is easier for quick analysis, because the function handles choosing the response function, slicing the result and passing it to the metric.
  2. y_true, y_pred is for users that understand how to work with the underlying metric and have the predictions already saved.

What's the problem with doing https://github.com/scikit-learn/scikit-learn/issues/15880#issuecomment-565489619 ?

I haven't read this whole thread, but if we allow the interface here, we also need to do it for plot_roc_curve, where the interface will be quite different between providing predictions and providing the estimator (one needs pos_label, the other doesn't).
So I think allowing both in the same interface is a bad idea (someone will pass pos_label when passing an estimator and get a result they don't expect).

ConfusionMatrixDisplay.from_estimator(...)
ConfusionMatrixDisplay.from_predictions(...)

Could work, but it would basically make the plot_confusion_matrix redundant, and so we would remove the functions again and change the responsibilities between the class and the function (we said the class doesn't do the compute).

If we want to add a from_predictions to plot_roc_curve it needs to basically mirror the roc_curve interface perfectly. So I don't think it's too bad to have the user call the roc_curve function directly and then pass the results to the Display object.

The whole purpose of the design of the Display objects was to allow the use case mentioned by @jhennrich, and that is why we separated the calculation from the function. I haven't seen an argument on why we should back down on that decision yet.

@amueller Technically you are right, the current solution to my Issue is to just use ConfusionMatrixDisplay. However it is really clumsy to use:

  • you have to pass the labels explicitly
  • you have to calculate the confusion matrix first
  • you have to create an object of the class and then still call the plot method

For all the applications I can think of, a plot_confusion_matrix signature with (y_true, y_pred, ...) would be much more convenient than what we currently have. In my opinion there are many more use-cases where you want to explicitly calculate the predictions (though I am sure my view is biased).

If you have a plot_confusion_matrix(y_true, y_pred) signature and you actually want to use it on estimator, x, y data, there is only very little additional code to write: plot_confusion_matrix(y, estimator.predict(x)).
In comparison, if you have the current signature and you want to plot from y_true and y_pred, you have to write a lot more code.

In my opinion the plot_confusion_matrix(y_true, y_pred) signature should be the default and another function that takes estimator, x, y should be built on top.
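To illustrate the asymmetry (the first call assumes the hypothetical (y_true, y_pred) signature; the rest uses the API that exists today, with estimator, X, y_true and y_pred standing for a fitted model and its data):

# Hypothetical (y_true, y_pred) signature: adapting it to an estimator is one line.
plot_confusion_matrix(y, estimator.predict(X))

# Current API: plotting from existing predictions goes through the Display object.
from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.utils.multiclass import unique_labels

cm = confusion_matrix(y_true, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm,
                              display_labels=unique_labels(y_true, y_pred))
disp.plot()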

Last but not least, I honestly don't really understand the idea behind the ConfusionMatrixDisplay class. The class only has a single constructor and exactly one method, so whenever you use it you end up creating an instance and calling the plot method. I don't see why this should be a class and not just a function. Also there are other *Display classes (PrecisionRecall, ROC, ...), but their constructor and plot() signatures are completely different, so they cannot be swapped out anyway.
Maybe this goes beyond the scope of this issue.

@jhennrich

If you have a plot_confusion_matrix(y_true, y_pred) signature and you actually want to use it on estimator, x, y data, there is only very little additional code to write: plot_confusion_matrix(y, estimator.predict(x)).

For the confusion matrix case, it is simple to pass in estimator.predict if we had a y_true, y_pred interface. On the other hand, for plot_roc_curve, the user would need to do slicing:

y_pred = est.predict_proba(X)
plot_roc_curve(y_true, y_pred[:, 1])

# or (decision_function already returns a 1-D score for binary problems)
y_pred = est.decision_function(X)
plot_roc_curve(y_true, y_pred)

Last but not least, I honestly don't really understand the idea behind the ConfusionMatrixDisplay class. The class only has a single constructor and exactly one method, so whenever you use it you end up creating an instance and calling the plot method. I don't see why this should be a class and not just a function.

The purpose of the Display objects is to store the computed values, allowing users to call plot many times without recomputing. This can be seen by using plot_partial_dependence:

# Does expensive computation
disp = plot_partial_dependence(est, ...)

# change line color without needing to recompute partial dependence
disp.plot(line_kw={"c": "red"})

Honestly, I am on the fence about this issue. I am +0.1 on moving toward copying the metrics interface for metrics plotting and removing the est, X, y interface. :/

For the confusion matrix case, it is simple to pass in estimator.predict if we had a y_true, y_pred interface. On the other hand, for plot_roc_curve, the user would need to do slicing:

Yes, but by doing so, we avoid calculating the prediction multiple times (though predicting is often not so expensive)

Perhaps a practical solution is to support y_true, y_pred in plot_XXX (when applicable) in 0.23.

@jhennrich How are you going to do this without passing the labels explicitly? If the labels can be inferred from what is given confusion_matrix will do that for you.

But indeed you're right, it's three lines instead of one.
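As a side note on label inference, confusion_matrix already infers the labels from the data when labels is not given:

from sklearn.metrics import confusion_matrix

y_true = ["cat", "dog", "cat", "dog"]
y_pred = ["cat", "cat", "cat", "dog"]

# With labels=None, the sorted union of y_true and y_pred is used.
print(confusion_matrix(y_true, y_pred))
# [[2 0]
#  [1 1]]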

In the case of confusion_matrix I tend to agree that the more common case might be passing y_true and y_pred.
The reason the interface currently is the way it is is to be consistent with the other metric plotting functions. As @thomasjpfan said, the roc curve is less obvious to plot.

Right now the code for plotting a confusion matrix and plotting a roc curve are the same. With your suggested change, they won't be the same any more, and there won't be an easy way to make them the same.

The question is whether in this case it's better to have consistent interfaces or to have a simple interface.
@jhennrich To me the real question is what the right interface for plot_roc_curve is. Do you have thoughts on that?

@thomasjpfan do you lean towards taking y_score for plotting roc auc as well?

There are certainly pros and cons to using the scorer interface instead of the metric interface. But for more complex things it's much safer to use the scorer interface.

@qinhanmin2014
I think it would be fine to add y_pred to plot_confusion_matrix. The question is whether we want to add y_score to plot_roc_curve and plot_precision_recall_curve. If we do then we also have to add pos_label as I said above, and things will become more complicated.

I see three ways out of this:
a) Only add y_pred to plot_confusion_matrix, but don't add y_score to plot_roc_curve etc. Downside: the problem of calling predict_proba multiple times keeps existing for these metrics.
b) Make it easier to use the Display object directly (though I don't really know how).
c) Add another method or function that mirrors the metric interface. Downside: bigger API surface.

I don't think that having the plot_X function mirror both the scorer and metric interface at the same time is a good idea in general.

I think it would be great to resolve this in some way @adrinjalali do you want to discuss it in the next meeting maybe?

I sometimes have nightmares about this issue. Maybe we can add a static method that takes the output of the metric directly:

result = confusion_matrix(...)
ConfusionMatrixDisplay.from_metric(result).plot()

For roc curve:

result = roc_curve(...)
RocCurveDisplay.from_metric(*result).plot()

On a side note, from looking at codebases, I think more users are familiar with the metrics interface than the score interface.

I sometimes have nightmares about this issue.

Oh no :(

On a side note, from looking at codebases, I think more users are familiar with the metrics interface than the score interface.

I think this is definitely true. But I'm also quite certain that people use y_pred when they should be using y_score and are getting wrong results because the interface doesn't tell you that you need to do something different and no-one ever reads the docs.

I'm not sure how the static method you propose is different from the constructor but maybe I'm overlooking something.

Hi, I've just up-voted the issue - as a long-time sklearn user, I found the current API for plot_confusion_matrix very... well, confusing. I really like its addition (less copy-pasting), but the metrics functions always used the (y_true, y_pred) scheme which is just more flexible and what I have already been used to.

In my case it doesn't make sense to pass an estimator in, as it's a very slow model and I'd rather load the predictions from a file than re-run it every time I want to analyze the results. I'm happy to have found out in this thread there's a work-around using the *Display object, but its discoverability is not great - I would suggest at least adding that to plot_confusion_matrix documentation or maybe confusion matrix user guide?

In my case it doesn't make sense to pass an estimator in, as it's a very slow model and I'd rather load the predictions

Thanks for your input. If the current API is confusing, it increasingly makes sense to move to a more metrics-API-like interface and go through a painful deprecation cycle.

The biggest concern we have with using the metrics interface is:

But I'm also quite certain that people use y_pred when they should be using y_score and are getting wrong results because the interface doesn't tell you that you need to do something different and no-one ever reads the docs.

@pzelasko What are your thoughts on this matter?

@thomasjpfan I understand the issue, it's a tough one. Maybe a reasonable compromise would be to allow only keyword arguments for this function (now that you don't have to support Python 2 anymore)? Like: def plot_confusion_matrix(*, y_true, y_pred, ...). It's still different from the rest of the metrics, but 1) it has a good reason for that, 2) it is at least using the same type of inputs as the other functions.

Anyway, I know why you're hesitant to make any API changes, that's why I suggested to at least mention the work-around in docs. (I've actually read them numerous times and I really appreciate them!)

The current way to use y_true and y_pred is shown here: https://scikit-learn.org/stable/auto_examples/miscellaneous/plot_display_object_visualization.html#create-confusionmatrixdisplay

I know I am stretching here but what about this:

plot_confusion_matrix(estimator='precomputed', y_true, y_pred, ...)

where, if estimator='precomputed', the remaining positional arguments are interpreted as y_true and y_pred.

if you want to stretch even more I would prefer plot_confusion_matrix((estimator, X, y), ...) or plot_confusion_matrix((y_true, y_pred), ...) but I am not sure it is solving the issues raised by @amueller regarding the metric-like API

There are a few new plotting utilities where allowing a metric API would really make sense:

I understand the issue that @amueller mentioned about needing to pass pos_label etc., but this isn't an issue for any of the aforementioned functions.

Are we OK to support both scorer and metrics API for these two? We don't need to worry about backward compatibility there.

I am still for my suggestion of using precomputed, which we commonly use in our estimators. In this case, the signature would be:

plot_confusion_matrix(estimator='precomputed', y_true, y_pred, ..., metric_kwargs=None)

I'll put together a PR to see what this looks like.

I'm not really discussing the API yet, I'm only asking whether we're OK to support both options for new PRs.

(But regarding the API, I don't think 'precomputed' helps much: what do we do about X? I think we should just keep (y_pred) and (estimator, X) mutually exclusive, by properly erroring. Also what does it mean for an estimator to be precomputed?)

Or estimator='none', estimator='predictions', estimator='precomputed_predictions', and then X becomes y_pred or y_score. It's almost like how we handle precomputed distances with X in estimators.

Are we OK to support both scorer and metrics API for these two?

How are we going to support both options? With two functions?

I would have also liked:

CalibrationDisplay.from_estimator(...)
CalibrationDisplay.from_predictions(...)

which would be two methods.

Guillaume's suggestion of using tuples https://github.com/scikit-learn/scikit-learn/issues/15880#issuecomment-670590882 is one option. I think it would have been the best option if we had started from there from the beginning. But I'm afraid using tuples breaks consistency with the existing utilities.

plot_XYZ(estimator=None, X=None, y=None, y_pred=None) with mutual exclusion is another option, and it's the one I'm advocating for, for now.
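A minimal sketch of that mutual exclusion (the signature and error messages are illustrative only, not an agreed design):

def plot_xyz(estimator=None, X=None, y=None, y_pred=None, **kwargs):
    # Illustrative validation: accept either (estimator, X, y) or (y, y_pred).
    if y_pred is not None and (estimator is not None or X is not None):
        raise ValueError("Pass either (estimator, X, y) or (y, y_pred), not both.")
    if y_pred is None and (estimator is None or X is None):
        raise ValueError("Either (estimator, X, y) or (y, y_pred) must be provided.")
    if y_pred is None:
        y_pred = estimator.predict(X)
    # ... compute the metric from (y, y_pred) and build the Display ...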

I like CalibrationDisplay.from_estimator(...), but as Andy noted, we'd need to remove the plot_XYZ functions then. It might be worth considering.

I think we can move to tuples and deprecate the current behavior. (As long as we agree to use tuples)

So this seems like discussing the namespaces, right?
Whether we have one function and one constructor, or two classmethods, or two functions, it's exactly the same functionality and basically the same code.

@pzelasko @jhennrich how do you feel about having two classmethods or two functions? Or would you prefer a single function, which is a bit messy in Python.

And if you prefer two functions or two classmethods, do you see any benefit besides discoverability? Discoverability might be enough of a reason to do classmethods though, I don't see a strong argument for having two functions.

Could we add the blocker label here? It seems that it is preventing progress on #18020 and #17443 (cc @cmarmo)

The blocker label is for release blockers (things that absolutely need to be fixed before a release), not for PR blockers

Ahh good to know.

@pzelasko @jhennrich how do you feel about having two classmethods or two functions? Or would you prefer a single function, which is a bit messy in Python.

And if you prefer two functions or two classmethods, do you see any benefit besides discoverability? Discoverability might be enough of a reason to do classmethods though, I don't see a strong argument for having two functions.

I like the two-classmethods approach the most, especially the from_xxx pattern - something like what @thomasjpfan proposed:

CalibrationDisplay.from_estimator(...)
CalibrationDisplay.from_predictions(...)

Looks like there's no strong opposition to using 2 class methods, so let's do that. We'll need to:

  • Introduce the class methods for the currently existing plots:

    • ConfusionMatrixDisplay
    • PrecisionRecallDisplay
    • RocCurveDisplay
    • DetCurveDisplay
    • PartialDependenceDisplay. For this one, we don't want to introduce the from_predictions classmethod because it would not make sense, we only want from_estimator.
  • For all the Display classes listed above, deprecate their corresponding plot_... function. We don't need to deprecate plot_det_curve because it hasn't been released yet; we can just remove it.

  • For new PRs like #17443 and #18020, we can implement the class methods right away instead of introducing a plot function.

This is a bit of work but I think we can get this done before 0.24 so that #17443 and #18020 can move forward already.
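As a rough sketch of what the two classmethods could do under the hood (written here as free functions for illustration only; the real classmethods would live on the Display classes and their exact signatures still need to be settled in the PRs):

from sklearn.metrics import ConfusionMatrixDisplay, confusion_matrix
from sklearn.utils.multiclass import unique_labels


def confusion_matrix_from_predictions(y_true, y_pred, display_labels=None, **plot_kwargs):
    # What ConfusionMatrixDisplay.from_predictions could do: compute the
    # confusion matrix once, build the Display, then plot it.
    cm = confusion_matrix(y_true, y_pred)
    if display_labels is None:
        display_labels = unique_labels(y_true, y_pred)
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=display_labels)
    disp.plot(**plot_kwargs)
    return disp


def confusion_matrix_from_estimator(estimator, X, y, **kwargs):
    # What ConfusionMatrixDisplay.from_estimator could do: predict once,
    # then fall through to the predictions path.
    return confusion_matrix_from_predictions(y, estimator.predict(X), **kwargs)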

Any objection @thomasjpfan @jnothman @amueller @glemaitre ?

@jhennrich @pzelasko , would you be interested in submitting a PR to introduce the class methods in one of the Display objects?

Thanks for making the decision @NicolasHug! I'll get onto #17443 (after waiting for objections)

I have no objections.

No objection as well.

I will take care of the other classes then and advance my stalled PR.
@lucyleeow in case I did not do all of those and you are searching for some PRs, ping me :)

I'd love to contribute but I'm engaged in too many projects at this time. Thanks for listening to the suggestions!

Sounds good :)
