MNT Add _check_sample_weights to classification metrics #31701
Conversation
@@ -596,7 +596,7 @@ def test_multilabel_confusion_matrix_errors():
     # Bad sample_weight
     with pytest.raises(ValueError, match="inconsistent numbers of samples"):
         multilabel_confusion_matrix(y_true, y_pred, sample_weight=[1, 2])
-    with pytest.raises(ValueError, match="should be a 1d array"):
+    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
Note that this is because this error is now being raised by _check_sample_weight:

scikit-learn/sklearn/utils/validation.py, lines 2207 to 2208 in 9489ee6:

    if sample_weight.ndim != 1:
        raise ValueError("Sample weights must be 1D array or scalar")

instead of by column_or_1d, which is called after _check_sample_weight:

    sample_weight = column_or_1d(sample_weight, device=device_)

This is actually a fix, because the old error message was: "y should be a 1d array, got an array of shape (3, 3) instead." - which is misleading, as the offending array was actually sample_weight.

For reference, here is some discussion in the original PR adding _check_sample_weight: https://github.com/scikit-learn/scikit-learn/pull/14307/files#r302938269
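A minimal repro sketch of the message change, assuming this PR's branch (on releases without the change, the old column_or_1d message appears instead):

```python
import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array([[1, 0], [0, 1], [1, 1]])
y_pred = np.array([[1, 0], [1, 1], [0, 1]])

# A 2D sample_weight of length 3 passes the consistent-length check but
# fails the ndim check, which is now raised by _check_sample_weight.
try:
    multilabel_confusion_matrix(y_true, y_pred, sample_weight=np.ones((3, 3)))
except ValueError as exc:
    print(exc)  # Sample weights must be 1D array or scalar
```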
I'm not 100% sure about the addition of …
Adding sample weight checking to _check_targets looks good because it allows removing a redundant check_consistent_length call. The name _check_targets feels a bit off now, though.

The _check_sample_weight calls added in the PR are only used for the checks, but not for the conversions or generations that _check_sample_weight is able to do. It would be interesting to check whether that's enough - passing an int as sample_weight, for instance. We might have to make _check_targets return the validated sample weights, like it does for y_true and y_pred.
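A hypothetical sketch of that suggestion - the name and signature below are illustrative only (the real _check_targets also returns the inferred target type), but it shows what returning the validated sample weights would buy callers:

```python
import numpy as np
from sklearn.utils.validation import (
    _check_sample_weight,
    check_consistent_length,
    column_or_1d,
)


def _check_targets_sketch(y_true, y_pred, sample_weight=None):
    # All length checks happen up front, in one place.
    check_consistent_length(y_true, y_pred, sample_weight)
    y_true = column_or_1d(y_true)
    y_pred = column_or_1d(y_pred)
    if sample_weight is not None:
        # Return the *converted* weights (e.g., int -> float64), instead of
        # running the checks and discarding the result.
        sample_weight = _check_sample_weight(sample_weight, y_true)
    return y_true, y_pred, sample_weight


y_true, y_pred, sw = _check_targets_sketch([0, 1, 1], [0, 1, 0], [1, 2, 3])
print(sw.dtype)  # float64: the int weights were converted, not only checked
```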
sklearn/metrics/_classification.py (outdated):

    else:
        sample_weight = np.asarray(sample_weight)
We could replace these lines with a call to _check_sample_weight, no?
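A sketch of what that replacement buys, assuming numpy inputs: besides the plain conversion that np.asarray does, the shared validator also handles None and scalar weights and enforces length and dimensionality:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

y_true = np.array([0, 1, 1, 0])

# instead of: sample_weight = np.asarray(sample_weight)
print(_check_sample_weight([1, 2, 3, 4], y_true))  # [1. 2. 3. 4.]
print(_check_sample_weight(None, y_true))          # [1. 1. 1. 1.]
print(_check_sample_weight(2.0, y_true))           # [2. 2. 2. 2.]
```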
That is a very good point 🤔 I don't know (we kept …)

🤦 I meant to return … Let me check what can be deleted with that.

Let's keep the name as is then :)
    xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)

    if sample_weight is None:
        weight_average = 1.0
    else:
        sample_weight = xp.asarray(sample_weight, device=device)
Although check_array inside _check_sample_weight does not specify device, because we use xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight), I think sample_weight should already be on the right device.
I thought that if the 3 arrays are not on the same device an error is raised, so indeed if we get there, they should be on the same device.
I suggested casting sample_weight to the same xp array type and device because I kept getting a DeprecationWarning when running the CUDA CI locally while reviewing #30838 😅 (see #30838 (comment) for reference).

The warning might be a false positive on my end or already outdated, but I think it's safer to include the casting line to ensure correctness of the namespace and device.
Hmm, but since we do get_namespace_and_device(y_true, y_pred, sample_weight) above, wouldn't it error if all 3 arrays are not in the same namespace / device?
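For illustration, a CPU-only sketch of how the namespace and device are derived in one place (get_namespace_and_device is a private helper; the mixed-device error path needs a second device to demonstrate, so it is only described in the comment):

```python
import numpy as np
from sklearn.utils._array_api import get_namespace_and_device

y_true = np.array([0.0, 1.0, 1.0])
y_pred = np.array([0.0, 1.0, 0.0])
sample_weight = np.array([1.0, 2.0, 3.0])

# A single namespace and device are derived from all three arrays; if the
# inputs lived on different devices, the device inspection raises instead
# of silently picking one of them.
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
print(device)  # cpu
```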
I've hit a snag.

scikit-learn/sklearn/utils/validation.py, lines 2188 to 2189 in 0872e9a

This was not a problem for regression targets because we mostly wanted it to be a float (i.e., we use …). For classification metrics, it is fine for … Looking at the uses of …, I am wondering if we could remove the … It is also confusing that if you specifically specify an int dtype, … The safer option would be to add another parameter to avoid upcasting to float when you specify a non-float dtype in …

WDYT?
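For context, a small demonstration of the upcasting behaviour in question on the numpy path of _check_sample_weight:

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((3, 2))
sw = np.array([1, 2, 3])  # integer sample weights

# With no dtype requested, integer weights are upcast to float64.
print(_check_sample_weight(sw, X).dtype)  # float64

# Requesting a non-float dtype does not stick either: it is overridden to
# float64 by the lines quoted above, which is the confusing part.
print(_check_sample_weight(sw, X, dtype=np.int64).dtype)  # float64
```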
Actually, I see two cases where I think it is possible that it is passed an int dtype: … and …
It's not straightforward. IIRC this upcast was added to avoid converting float sample weights to integers, possibly losing precision in the process.

+1 to try that approach
sklearn/metrics/_classification.py (outdated):

    if sample_weight is not None:
        sample_weight = column_or_1d(sample_weight, device=device_)
We keep that for the device, right? It feels a bit weird. Do you think _check_targets should be responsible for this?
I kept it for the reshape, but now realise that since sample_weight has to be 1D, the reshape would do nothing 🤦

I would say with #28668, that part would be handled by a next function that does conversions...

I would delete this; I don't think it was designed for handling device conversion. _asarray_with_order uses xp.asarray, for which the Array API definition is:

    object to be converted to an array. May be a Python scalar, a (possibly
    nested) sequence of Python scalars, or an object supporting the Python
    buffer protocol.

With this definition, you couldn't pass, e.g., a CuPy array. And for device conversion of, e.g., torch GPU and CPU arrays, it would not work either.
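A small illustration of that reading of the spec, using numpy as the namespace (the CuPy/torch cases are only described in the comment, since they need a GPU):

```python
import numpy as np

# In-spec inputs for xp.asarray: Python scalars, nested sequences of
# scalars, and objects supporting the buffer protocol.
print(np.asarray(2.0))                     # scalar
print(np.asarray([[1, 2], [3, 4]]))        # nested sequence
print(np.asarray(bytearray(b"\x01\x02")))  # buffer protocol -> uint8 array

# A foreign-namespace array (e.g., a CuPy array, or a torch GPU tensor that
# should land on CPU) is not covered by that definition, so xp.asarray is
# not a general cross-namespace or cross-device conversion tool.
```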
Something that is not clear to me is when …

I don't know off the top of my head, but it feels like it could be convenient to be able to move … I guess it's possible to imagine a Pipeline where …

I think we should allow device movement, but I am not sure …
The device movements can be discussed in a separate issue imo. This PR LGTM. Thanks @lucyleeow
Thanks for the PR @lucyleeow.
LGTM. Thanks @lucyleeow
Reference Issues/PRs

Follow up to #30886

What does this implement/fix? Explain your changes.

- Add check_consistent_length on y_true, y_prob and sample_weights in _check_targets - this avoids the second check_consistent_length, which means that all length checks occur at the start and you know who is raising errors (note this is not about avoiding the double checking, as they are not expensive checks).
- Add _check_sample_weight to _check_targets, _validate_multiclass_probabilistic_prediction and _validate_binary_probabilistic_prediction - I am not 100% sure on this. Currently this check is only being done in d2_log_loss_score. This check does the following (see the sketch after this list):
  - converts ComplexWarning to error (though not sure if this warning is only raised for numpy arrays or other array API arrays)
  - checks array.ndim is not greater than 3

These seem like reasonable checks to have. The only potential downside is that these checks would take a bit more time, but I don't think this is really a problem.
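A sketch of those two checks in action on the numpy path (the helpers are private and the exact messages may differ between versions):

```python
import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((4, 2))

# Complex weights: check_array escalates numpy's ComplexWarning to an error.
try:
    _check_sample_weight(np.array([1j, 2j, 3j, 4j]), X)
except ValueError as exc:
    print(exc)  # Complex data not supported

# Multi-dimensional weights are rejected by the ndim check.
try:
    _check_sample_weight(np.ones((4, 2)), X)
except ValueError as exc:
    print(exc)  # Sample weights must be 1D array or scalar
```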
cc @ogrisel
Any other comments?