
MNT Add _check_sample_weights to classification metrics #31701


Merged

merged 18 commits into scikit-learn:main from lucyleeow:class_checks on Jul 27, 2025

Conversation

lucyleeow (Member)

Reference Issues/PRs

Follow up to #30886

What does this implement/fix? Explain your changes.

  • Perform check_consistent_length on y_true, y_prob and sample_weight in _check_targets. This avoids the second check_consistent_length call, so all length checks occur at the start and it is clear which function is raising the error (note this is not about avoiding the double checking, as these are not expensive checks).
  • Adds _check_sample_weight to _check_targets, _validate_multiclass_probabilistic_prediction and _validate_binary_probabilistic_prediction. I am not 100% sure about this; currently this check is only done in d2_log_loss_score. The check does the following (see the sketch after this list):
    • ensures all values are finite
    • ensures the data is not complex (i.e., converts ComplexWarning to an error, though I am not sure whether this warning is raised only for NumPy arrays or for other array API arrays too)
    • ensures array.ndim is less than 3 (and, in fact, that sample weights are a 1D array or scalar)

These seem like reasonable checks to have. The only potential downside is that they take a bit more time, but I don't think this is really a problem.
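
A quick sketch of those checks in action, using the private helper from sklearn.utils.validation (an internal API; exact messages and signature may differ between versions):

import numpy as np
from sklearn.utils.validation import _check_sample_weight

y_true = np.array([0, 1, 1])

# None is expanded to uniform unit weights.
print(_check_sample_weight(None, y_true))  # [1. 1. 1.]

# Non-finite values are rejected.
try:
    _check_sample_weight(np.array([1.0, np.nan, 1.0]), y_true)
except ValueError as exc:
    print(exc)

# Anything that is not 1D (or a scalar) is rejected.
try:
    _check_sample_weight(np.ones((3, 3)), y_true)
except ValueError as exc:
    print(exc)  # Sample weights must be 1D array or scalar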

cc @ogrisel

Any other comments?

@lucyleeow lucyleeow marked this pull request as draft July 4, 2025 12:12

github-actions bot commented Jul 4, 2025

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 4f2759e.

@@ -596,7 +596,7 @@ def test_multilabel_confusion_matrix_errors():
     # Bad sample_weight
     with pytest.raises(ValueError, match="inconsistent numbers of samples"):
         multilabel_confusion_matrix(y_true, y_pred, sample_weight=[1, 2])
-    with pytest.raises(ValueError, match="should be a 1d array"):
+    with pytest.raises(ValueError, match="Sample weights must be 1D array or scalar"):
@lucyleeow lucyleeow (Member Author) commented Jul 7, 2025

Note that this is because the error is now raised by _check_sample_weight:

if sample_weight.ndim != 1:
    raise ValueError("Sample weights must be 1D array or scalar")

instead of by column_or_1d, which is called after _check_sample_weight:

sample_weight = column_or_1d(sample_weight, device=device_)

This is actually a fix, because the old error message was "y should be a 1d array, got an array of shape (3, 3) instead.", which is misleading since the offending array was actually sample_weight.

For reference here is some discussion in the original PR adding _check_sample_weight: https://github.com/scikit-learn/scikit-learn/pull/14307/files#r302938269
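
As an illustration, here is a small sketch assuming the post-PR behavior exercised by the test above:

import numpy as np
from sklearn.metrics import multilabel_confusion_matrix

y_true = np.array([0, 1, 1])
y_pred = np.array([0, 1, 0])

try:
    multilabel_confusion_matrix(y_true, y_pred, sample_weight=np.ones((3, 3)))
except ValueError as exc:
    # Before: "y should be a 1d array, got an array of shape (3, 3) instead."
    print(exc)  # Sample weights must be 1D array or scalar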

@lucyleeow (Member Author)

I'm not 100% sure about the addition of _check_sample_weight, so I won't add tests until we decide we are happy with this change. Thanks!

@lucyleeow lucyleeow marked this pull request as ready for review July 7, 2025 04:12
@jeremiedbb jeremiedbb (Member) left a comment

Adding sample weight checking to _check_targets looks good because it lets us remove a redundant check_consistent_length call. The name _check_targets feels a bit off now, though.

The _check_sample_weight calls added in this PR are only used for the checks, not for the conversions or generation that _check_sample_weight is able to do. It would be interesting to check whether that is enough, for instance when passing an int as sample_weight. We might have to make _check_targets return the validated sample weights, like it does for y_true and y_pred.
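
A minimal sketch of that suggestion (the function here is a simplified stand-in, not scikit-learn's actual implementation): _check_targets would validate sample_weight and return the converted array, mirroring what it already does for y_true and y_pred.

import numpy as np
from sklearn.utils import check_consistent_length
from sklearn.utils.validation import _check_sample_weight

def _check_targets_sketch(y_true, y_pred, sample_weight=None):
    # One consistency check up front covers all three inputs.
    check_consistent_length(y_true, y_pred, sample_weight)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    if sample_weight is not None:
        # Converts lists/scalars to a validated 1D array, not just a check.
        sample_weight = _check_sample_weight(sample_weight, y_true)
    return y_true, y_pred, sample_weight

_, _, sw = _check_targets_sketch([0, 1, 1], [0, 1, 0], sample_weight=[1, 2, 3])
print(sw, sw.dtype)  # [1. 2. 3.] float64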

Comment on lines 491 to 492
else:
    sample_weight = np.asarray(sample_weight)
Member

We could replace these lines with a call to _check_sample_weight, no?

@lucyleeow lucyleeow (Member Author) commented Jul 7, 2025

The name _check_targets feels a bit off now, though.

That is a very good point 🤔 I don't know (we kept _check_reg_targets in #30886, but maybe we could consider changing the name...?)

The _check_sample_weight calls added in this PR are only used for the checks, not for the conversions or generation that _check_sample_weight is able to do.

🤦 I meant to return sample_weight like I did in #30886; sorry, that was a brain fart. Thanks for your patience.

Let me check what can be deleted with that.

@jeremiedbb (Member)

we kept _check_reg_targets in #30886,

Let's keep the name as is then :)


xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)

if sample_weight is None:
    weight_average = 1.0
else:
    sample_weight = xp.asarray(sample_weight, device=device)
Member Author

Although check_array inside _check_sample_weight does not specify a device, because we use xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight) I think sample_weight should already be on the right device.

Member

I thought that if the 3 arrays are not on the same device an error is raised, so indeed if we get there, they should be on the same device.

Member

I suggested casting sample_weight to the same xp array type and device, because I kept getting a DeprecationWarning when running the CUDA CI locally while reviewing #30838 😅. (See #30838 (comment) for reference.)

The warning might be a false positive on my end or already outdated, but I think it's safer to include the casting line to ensure correctness of the namespace and device.

Member Author

Hmm, but since we call get_namespace_and_device(y_true, y_pred, sample_weight) above, wouldn't that error if the 3 arrays are not in the same namespace/device?
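
For reference, a sketch with plain NumPy inputs, using the private helper quoted above (an internal API that may change); mixing arrays from different namespaces or devices is expected to raise inside the helper instead of silently converting:

import numpy as np
from sklearn.utils._array_api import get_namespace_and_device

y_true = np.array([0.0, 1.0, 1.0])
y_pred = np.array([0.1, 0.9, 0.8])
sample_weight = np.array([1.0, 2.0, 1.0])

# All three inputs are inspected together, so code after this call can
# assume a single shared namespace and device.
xp, _, device = get_namespace_and_device(y_true, y_pred, sample_weight)
print(device)  # cpu for NumPy inputs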

@lucyleeow lucyleeow (Member Author) commented Jul 9, 2025

I've hit a snag. _check_sample_weight forces sample_weight to be a float, even if you specifically pass an int dtype:

if dtype is not None and dtype not in float_dtypes:
    dtype = max_float_type

This was not a problem for regression targets because we mostly wanted a float anyway (i.e., we use _check_reg_targets_with_floating_dtype), and when _check_reg_targets_with_floating_dtype wasn't explicitly used, we were passing sample_weight to _averaged_weighted_percentile, where we upcast sample_weight anyway.

For classification metrics, it is fine for sample_weight to be an int or bool (and indeed we check for this in test_confusion_matrix_dtype).

Looking at the uses of _check_sample_weight where a dtype was specifically passed (see search), it has typically always been a float (because we did an array check/validation beforehand and ensured y or X is float32 or float64).

I am wondering if we could remove the dtype not in float_dtypes part...?

It is also confusing that if you specifically specify an int dtype, _check_sample_weight will still upcast to float64.
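
A small sketch of that confusing behavior (private helper, behavior as quoted above at the time of writing):

import numpy as np
from sklearn.utils.validation import _check_sample_weight

X = np.zeros((3, 2))
sw = _check_sample_weight(np.array([1, 2, 3]), X, dtype=np.int64)
# Despite explicitly requesting int64, the non-float dtype is replaced by
# the maximal float type.
print(sw.dtype)  # float64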

The safer option would be to add another parameter to avoid upcasting to float when you specify a non-float dtype in _check_sample_weight.

WDYT?

@lucyleeow (Member Author)

Actually, I see two call sites where I think it is possible that an int dtype is passed:

sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

and

sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype)

@jeremiedbb (Member)

I am wondering if we could remove the dtype not in float_dtypes part...?

It's not straightforward. IIRC this upcast was added to avoid converting float sample weights to integers, possibly losing precision in the process.
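
For instance, a one-line sketch of the precision loss the upcast guards against:

import numpy as np

# Fractional weights would silently truncate if cast to an integer dtype.
print(np.asarray([0.5, 1.5], dtype=np.int64))  # [0 1]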

The safer option would be to add another parameter to avoid upcasting to float when you specify a non-float dtype in _check_sample_weight.

+1 to try that approach

Comment on lines 671 to 672
if sample_weight is not None:
    sample_weight = column_or_1d(sample_weight, device=device_)
Member

We keep that for the device, right? It feels a bit weird. Do you think _check_targets should be responsible for this?

Member Author

I kept it for the reshape, but now realise that since sample_weight has to be 1D, the reshape would do nothing 🤦

I would say that with #28668, that part would be handled by a subsequent function that does the conversions...

I would delete this; I don't think it was designed for handling device conversion. _asarray_with_order uses xp.asarray, for which the Array API definition of the input is:

object to be converted to an array. May be a Python scalar, a (possibly nested) sequence of Python scalars, or an object supporting the Python buffer protocol.

With this definition, you couldn't pass e.g. a CuPy array. And for device conversion of e.g. torch GPU and CPU arrays, it would not work either.

@jeremiedbb (Member)

Something that is not clear to me: when sample_weight is not on the same device as y_true/y_pred, do we want to raise an error, or move it to that device? It seems that there are cases where we do one and cases where we do the other, but maybe I missed something. Maybe @lesteve has more insights?

@lesteve lesteve (Member) commented Jul 21, 2025

I don't know off the top of my head, but it feels like it could be convenient to be able to move sample_weight to the device? Honestly, @lucyleeow is probably way more qualified than me on this, since she was closely involved in the "y_true follows y_pred" discussion.

I guess it's possible to imagine a Pipeline where sample_weight needs to be on the CPU for a given step and on the GPU for another one.

@lucyleeow (Member Author)

I think we should allow device movement, but I am not sure _check_sample_weight should be the one to do it. We have "everything follows X" and "everything follows y_pred" conventions, which means there may be other input arrays that need to be on the same device as X/y_pred. It may be best to leave this to a conversion function that handles all input arrays that need to share a device/namespace.

@jeremiedbb jeremiedbb (Member) left a comment

The device movements can be discussed in a separate issue imo. This PR LGTM. Thanks @lucyleeow

@jeremiedbb jeremiedbb added the Waiting for Second Reviewer First reviewer is done, need a second one! label Jul 22, 2025
@OmarManzoor OmarManzoor (Contributor) left a comment

Thanks for the PR @lucyleeow.

@OmarManzoor OmarManzoor (Contributor) left a comment

LGTM. Thanks @lucyleeow

@OmarManzoor OmarManzoor merged commit 27e5256 into scikit-learn:main Jul 27, 2025
36 checks passed
@lucyleeow lucyleeow deleted the class_checks branch July 27, 2025 08:06