Skip to content

FEA add temperature scaling to CalibratedClassifierCV #31068

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 79 commits into
base: main
Choose a base branch
from

Conversation

virchan
Copy link
Member

@virchan virchan commented Mar 25, 2025

Reference Issues/PRs

Closes #28574

What does this implement/fix? Explain your changes.

This PR adds temperature scaling to scikit-learn's CalibratedClassifierCV:

Temperature scaling can be enabled by setting method = "temperature" in CalibratedClassifierCV:

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.calibration import CalibratedClassifierCV
from sklearn.svm import LinearSVC

X, y = make_classification(random_state=42)

X_train, X_calib, y_train, y_calib = train_test_split(X, y, random_state=42)

clf = LinearSVC(random_state=42)
clf.fit(X_train, y_train)
cal_clf = CalibratedClassifierCV(clf, method="temperature").fit(X_train, y_train)

This method supports both binary and multi-class classification.

Any other comments?

Cc @adrinjalali, @lorentzenchr in advance.

Copy link

github-actions bot commented Mar 25, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: ad2a786. Link to the linter CI: here

Copy link
Member Author

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A follow-up to my comment on the Array API: I don't think we can support the Array API here, as scipy.optimize.minimize does not appear to support it.

If I missed anything, please let me know—I'd be happy to investigate further.

@virchan virchan marked this pull request as ready for review March 25, 2025 10:55
Copy link
Member

@ogrisel ogrisel left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR. Here is a first pass of feedback:

virchan added 4 commits March 27, 2025 18:14
…fier`.

Updated constructor of `_TemperatureScaling` class.
Updated `test_temperature_scaling` in `test_calibration.py`.
Added `__sklearn_tags__` to `_TemperatureScaling` class.
Copy link
Member Author

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm still working on addressing the feedback, but I also wanted to share some findings related to it and provide an update.

Copy link
Member

@lorentzenchr lorentzenchr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I few computational things seem off.

virchan added 2 commits April 25, 2025 22:16
Update `minimize` in `_temperture_scaling` to `minimize.scalar`.
Update `test_calibration.py` to check the optimised inverse temperature is between 0.1 and 10.
Copy link
Member Author

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are some CI failures—I'll fix those shortly.

Also considering adding a verbose parameter to CalibratedClassifierCV to optionally display convergence info when optimising the inverse temperature beta.

Copy link
Member Author

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI passed!

Copy link
Contributor

@OmarManzoor OmarManzoor left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thank you @virchan

Copy link
Member Author

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The CI fails when checking that the ROC AUCs are equal up to 7 decimal places. I'll fix it later.

Copy link
Member Author

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CI passed!

Copy link
Member

@lorentzenchr lorentzenchr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Close to the finish line.

Comment on lines +1062 to +1064
l = halfmulti_loss.loss(y_true=labels, raw_prediction=raw_prediction)

return l.sum()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
l = halfmulti_loss.loss(y_true=labels, raw_prediction=raw_prediction)
return l.sum()
return halfmulti_loss(y_true=labels, raw_prediction=raw_prediction)

The call method returns the (weighted) average loss.

return logits


def _temperature_scaling(predictions, labels, sample_weight=None):
Copy link
Member

@lorentzenchr lorentzenchr Jul 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is only called in _TemperatureScaling.fit. So why not place the code there instead of this separate function?

@@ -359,7 +377,7 @@ def test_calibration_prefit(csr_container):
)


@pytest.mark.parametrize("method", ["sigmoid", "isotonic"])
@pytest.mark.parametrize("method", ["sigmoid", "isotonic", "temperature"])
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@pytest.mark.parametrize("method", ["sigmoid", "isotonic", "temperature"])
@pytest.mark.parametrize(
["method", "calibrator"],
[
("sigmoid", _SigmoidCalibration()),
("isotonic", IsotonicRegression(out_of_bounds="clip")),
("temperature", _TemperatureScaling())
]
)

and then remove the if-else inside.

Copy link
Member Author

@virchan virchan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated the user guide and the docstrings in calibration.py. I also modified the test to check that the temperature parameter is close to 1 when the temperature scaler is fitted on the training set of the LogisticRegression classifier.

There are still some comments that need to be addressed, and I'll work on them later.

Comment on lines +492 to +494
y_scores_train = clf.predict_proba(X_train)
ts = _TemperatureScaling().fit(y_scores_train, y_train)
assert_allclose(ts.beta_, 1.0, atol=1.1e-7, rtol=0)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For logistic regression, we could even check that temperature is 1 if fitted on the training set.

Let me know if atol or rtol needs to be adjusted here.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Implement temperature scaling for (multi-class) calibration
5 participants