FEA add temperature scaling to CalibratedClassifierCV #31068


Open: wants to merge 94 commits into base: main

Member

@virchan virchan commented Mar 25, 2025

Reference Issues/PRs

Closes #28574

What does this implement/fix? Explain your changes.

This PR adds temperature scaling to scikit-learn's CalibratedClassifierCV.

Temperature scaling can be enabled by setting method="temperature" in CalibratedClassifierCV:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import LinearSVC

X, y = make_classification(random_state=42)

X_train, X_calib, y_train, y_calib = train_test_split(X, y, random_state=42)

clf = LinearSVC(random_state=42)
clf.fit(X_train, y_train)
cal_clf = CalibratedClassifierCV(clf, method="temperature").fit(X_train, y_train)

This method supports both binary and multi-class classification.
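
For context, temperature scaling rescales a classifier's logits by a single positive scalar before applying the softmax. The snippet below is a minimal NumPy sketch of that transform, not the code in this PR; the function name `scaled_softmax` and the parameter name `beta` (the inverse temperature) are assumptions chosen for illustration.

```python
import numpy as np


def scaled_softmax(logits, beta):
    """Row-wise softmax of beta * logits, computed in a numerically stable way."""
    z = beta * np.asarray(logits, dtype=float)
    z = z - z.max(axis=1, keepdims=True)  # subtract the row max to avoid overflow
    exp_z = np.exp(z)
    return exp_z / exp_z.sum(axis=1, keepdims=True)


# Hypothetical logits for two samples and three classes.
logits = np.array([[2.0, 0.5, -1.0], [0.1, 0.3, 0.2]])

sharp = scaled_softmax(logits, beta=2.0)  # beta > 1 sharpens the distribution
soft = scaled_softmax(logits, beta=0.5)   # beta < 1 flattens it
```

Because `beta` is positive, the ranking of the classes within each row is unchanged, so the predicted labels stay the same and only the confidence of the predictions moves.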

Any other comments?

Cc @adrinjalali, @lorentzenchr in advance.

github-actions bot commented Mar 25, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 8610533.

Member Author

@virchan virchan left a comment

A follow-up to my comment on the Array API: I don't think we can support the Array API here, as scipy.optimize.minimize does not appear to support it.

If I missed anything, please let me know—I'd be happy to investigate further.

@virchan virchan marked this pull request as ready for review March 25, 2025 10:55
Member

@ogrisel ogrisel left a comment

Thanks for the PR. Here is a first pass of feedback:

virchan added 4 commits March 27, 2025 18:14
…fier`.

Updated constructor of `_TemperatureScaling` class.
Updated `test_temperature_scaling` in `test_calibration.py`.
Added `__sklearn_tags__` to `_TemperatureScaling` class.
Member Author

@virchan virchan left a comment

I'm still working on addressing the feedback, but I also wanted to share some findings related to it and provide an update.

Member

@lorentzenchr lorentzenchr left a comment

A few computational things seem off.

virchan added 2 commits April 25, 2025 22:16
Update `minimize` in `_temperature_scaling` to `minimize_scalar`.
Update `test_calibration.py` to check the optimised inverse temperature is between 0.1 and 10.
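
As a rough illustration of the scalar optimisation these commits refer to, the sketch below fits an inverse temperature with `scipy.optimize.minimize_scalar` by minimising the mean log loss. It is not the PR's implementation: the helper name `fit_inverse_temperature`, the choice of the bounded solver, reusing the 0.1 to 10 range from the test as search bounds, and the synthetic data are all assumptions.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from scipy.special import logsumexp


def fit_inverse_temperature(logits, y, bounds=(0.1, 10.0)):
    """Return the beta within `bounds` that minimises the mean log loss
    of softmax(beta * logits) against the integer labels y."""
    logits = np.asarray(logits, dtype=float)
    rows = np.arange(logits.shape[0])

    def mean_log_loss(beta):
        z = beta * logits
        log_proba = z - logsumexp(z, axis=1, keepdims=True)  # log-softmax
        return -log_proba[rows, y].mean()

    result = minimize_scalar(mean_log_loss, bounds=bounds, method="bounded")
    return result.x


rng = np.random.default_rng(0)
y = rng.integers(0, 3, size=500)
# Synthetic, overconfident logits: the true class gets a noisy boost,
# then every logit is multiplied by 4.
logits = 4.0 * (np.eye(3)[y] + rng.normal(size=(500, 3)))

beta = fit_inverse_temperature(logits, y)
```

Since the synthetic logits are inflated on purpose, the fitted `beta` should come out below 1, which shrinks the logits back toward calibrated confidence.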
Member Author

@virchan virchan left a comment

There are some CI failures—I'll fix those shortly.

I'm also considering adding a verbose parameter to CalibratedClassifierCV to optionally display convergence info when optimising the inverse temperature beta.

Member Author

@virchan virchan left a comment

CI passed!

Member

@lorentzenchr lorentzenchr left a comment

Close to the finish line.

Member Author

@virchan virchan left a comment

I updated the user guide and the docstrings in calibration.py. I also modified the test to check that the temperature parameter is close to 1 when the temperature scaler is fitted on the training set of the LogisticRegression classifier.
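
The idea behind that test can be sketched outside of CalibratedClassifierCV. A logistic regression fitted by (near) maximum likelihood already minimises the log loss on its training set, so rescaling its decision values should not help and the fitted inverse temperature should land near 1. The snippet below is an illustration under assumptions, not the test in `test_calibration.py`: the dataset, the large `C` used to make the penalty negligible, and the 0.1 to 10 search bounds are all chosen for the example.

```python
import numpy as np
from scipy.optimize import minimize_scalar
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=1000, random_state=0)

# A large C makes the L2 penalty negligible, so the fit approximates
# maximum likelihood.
clf = LogisticRegression(C=1e6, max_iter=1000).fit(X, y)
scores = clf.decision_function(X)  # binary case: one logit per sample
signs = 2 * y - 1                  # map labels {0, 1} to {-1, +1}


def mean_log_loss(beta):
    # Binary log loss of sigmoid(beta * score), written stably with logaddexp.
    return np.logaddexp(0.0, -beta * signs * scores).mean()


beta = minimize_scalar(mean_log_loss, bounds=(0.1, 10.0), method="bounded").x
```

With the default, stronger regularisation the coefficients are shrunk, so the fitted value would be expected to drift somewhat above 1 rather than sit on it.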

There are still some comments that need to be addressed, and I'll work on them later.

Member Author

@virchan virchan left a comment

I've refactored the check of response_method_name:

if len(classes) == 2 and predictions.shape[-1] == 1:
    response_method_name = _check_response_method(
        clf,
        ["decision_function", "predict_proba"],
    ).__name__
    if response_method_name == "predict_proba":
        predictions = np.hstack([1 - predictions, predictions])

I think this only needs to be applied in two places: _fit_calibrator and _CalibratedClassifier.predict_proba. But please let me know if there's a better way to handle this.
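
To make the `np.hstack` step in the snippet above concrete, here is a small standalone illustration with made-up values. It assumes, as the shape check suggests, that `predictions` holds only the positive-class probability in a single column.

```python
import numpy as np

# Hypothetical positive-class probabilities with shape (n_samples, 1).
predictions = np.array([[0.9], [0.2], [0.55]])

# Prepend the complementary column so each row reads [P(class 0), P(class 1)].
two_column = np.hstack([1 - predictions, predictions])
```

The result has shape `(n_samples, 2)` and each row sums to 1, which is the layout a multi-class style calibrator expects.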

I've also moved _temperature_scaling inside _TemperatureScaling.fit.

CI has passed, so it's ready for review!

Member Author

@virchan virchan left a comment

I updated _convert_to_logits to handle the conversion to a 2D array in _TemperatureScaling.fit and _TemperatureScaling.predict when the input is 1D.

_TemperatureScaling now has a new parameter, response_method_name ("decision_function" or "predict_proba"), which indicates whether it should fit or predict based on the output of decision_function or predict_proba, respectively. The default value is "decision_function".

In the CalibratedClassifierCV workflow, this value is computed in _fit_calibrator, then passed to _TemperatureScaling when the calibrator is initialised. If the input is 1D, _convert_to_logits will interpret it as either probabilities or decision values, accordingly.

This addresses an edge case in which a 1D output of predict_proba was incorrectly converted to logits as [-x, x]; the test_calibration_prefit function first caught it.
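
A small numeric example shows why the two interpretations must be kept apart. This is only a sketch with made-up values and a local `softmax` helper, not the `_convert_to_logits` code from this PR. Stacking probabilities as `[-x, x]` yields `sigmoid(2 * x)` after the softmax, whereas taking logs of `[1 - x, x]` recovers the original probabilities.

```python
import numpy as np


def softmax(z):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = z - z.max(axis=1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)


p = np.array([0.1, 0.5, 0.9])  # hypothetical 1D positive-class probabilities

# Misreading p as decision values: softmax([-p, p])[:, 1] equals sigmoid(2 * p).
as_decision = softmax(np.column_stack([-p, p]))[:, 1]

# Reading p as probabilities: softmax(log([1 - p, p])) gives p back.
as_proba = softmax(np.log(np.column_stack([1 - p, p])))[:, 1]
```

Here `as_proba` matches `p`, while `as_decision` is squeezed into roughly the 0.55 to 0.86 range, so the calibrator would start from distorted inputs.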

Previously, I attempted something similar in 2769eab, but I thought it was awkward for _TemperatureScaling to store response_method_name, so I didn't finalise it at the time. We'll see how this version goes.

@lorentzenchr
Member

Sorry for the back and forth: could you revert the last changes with response_method_name (c4ec0e8)? It seems cleaner after all.

Successfully merging this pull request may close these issues.

Implement temperature scaling for (multi-class) calibration