-
-
Notifications
You must be signed in to change notification settings - Fork 26.1k
FEA add temperature scaling to CalibratedClassifierCV
#31068
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
FEA add temperature scaling to CalibratedClassifierCV
#31068
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
A follow-up to my comment on the Array API: I don't think we can support the Array API here, as scipy.optimize.minimize
does not appear to support it.
If I missed anything, please let me know—I'd be happy to investigate further.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR. Here is a first pass of feedback:
…enting_temperature_scaling
…fier`. Updated constructor of `_TemperatureScaling` class. Updated `test_temperature_scaling` in `test_calibration.py`. Added `__sklearn_tags__` to `_TemperatureScaling` class.
…enting_temperature_scaling
…enting_temperature_scaling
…enting_temperature_scaling
…enting_temperature_scaling
…enting_temperature_scaling
…Updated doc-strings of temperature scaling in `calibration.py`. Updated formatting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm still working on addressing the feedback, but I also wanted to share some findings related to it and provide an update.
…enting_temperature_scaling
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I few computational things seem off.
…enting_temperature_scaling
Update `minimize` in `_temperture_scaling` to `minimize.scalar`. Update `test_calibration.py` to check the optimised inverse temperature is between 0.1 and 10.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some CI failures—I'll fix those shortly.
Also considering adding a verbose
parameter to CalibratedClassifierCV
to optionally display convergence info when optimising the inverse temperature beta
.
…enting_temperature_scaling
…id `method` parameter.
…ce in all `ensemble` cases.
…enting_temperature_scaling
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CI passed!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thank you @virchan
…enting_temperature_scaling
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The CI fails when checking that the ROC AUCs are equal up to 7 decimal places. I'll fix it later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CI passed!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Close to the finish line.
l = halfmulti_loss.loss(y_true=labels, raw_prediction=raw_prediction) | ||
|
||
return l.sum() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
l = halfmulti_loss.loss(y_true=labels, raw_prediction=raw_prediction) | |
return l.sum() | |
return halfmulti_loss(y_true=labels, raw_prediction=raw_prediction) |
The call method returns the (weighted) average loss.
return logits | ||
|
||
|
||
def _temperature_scaling(predictions, labels, sample_weight=None): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is only called in _TemperatureScaling.fit
. So why not place the code there instead of this separate function?
@@ -359,7 +377,7 @@ def test_calibration_prefit(csr_container): | |||
) | |||
|
|||
|
|||
@pytest.mark.parametrize("method", ["sigmoid", "isotonic"]) | |||
@pytest.mark.parametrize("method", ["sigmoid", "isotonic", "temperature"]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@pytest.mark.parametrize("method", ["sigmoid", "isotonic", "temperature"]) | |
@pytest.mark.parametrize( | |
["method", "calibrator"], | |
[ | |
("sigmoid", _SigmoidCalibration()), | |
("isotonic", IsotonicRegression(out_of_bounds="clip")), | |
("temperature", _TemperatureScaling()) | |
] | |
) |
and then remove the if-else inside.
…eter is close to 1.0 when fitted on the training set of LogisticRegression.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I updated the user guide and the docstrings in calibration.py
. I also modified the test to check that the temperature parameter is close to 1 when the temperature scaler is fitted on the training set of the LogisticRegression
classifier.
There are still some comments that need to be addressed, and I'll work on them later.
y_scores_train = clf.predict_proba(X_train) | ||
ts = _TemperatureScaling().fit(y_scores_train, y_train) | ||
assert_allclose(ts.beta_, 1.0, atol=1.1e-7, rtol=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For logistic regression, we could even check that temperature is 1 if fitted on the training set.
Let me know if atol or rtol needs to be adjusted here.
Reference Issues/PRs
Closes #28574
What does this implement/fix? Explain your changes.
This PR adds temperature scaling to scikit-learn's
CalibratedClassifierCV
:Temperature scaling can be enabled by setting
method = "temperature"
inCalibratedClassifierCV
:This method supports both binary and multi-class classification.
Any other comments?
Cc @adrinjalali, @lorentzenchr in advance.