Adapting step_size and sag updates for sample_weight #31837

Draft: wants to merge 1 commit into main

@antoinebaker (Contributor) commented Jul 25, 2025

Reference Issues/PRs

Helper for PR #31675 to fix sample_weight support in SAG(A) solvers #31536.

What does this implement/fix? Explain your changes.

In test_sag.py several fixes are made to properly handle sample_weight.

  1. get_step_size: the formula for the Lipschitz smoothness constant is corrected to take sample_weight into account
  2. sag: the updates for the weights and intercept are corrected
  3. sag: the number of seen elements (which converges to n_samples) is replaced by the weighted sum of seen elements (which converges to sample_weight.sum())
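
Point 3 can be sketched in a few lines. This is only an illustration, not the code from test_sag.py: the function name weighted_seen_sum and its arguments are hypothetical. It shows that accumulating the weight of each sample the first time it is drawn yields a counter that ends at sample_weight.sum() rather than at n_samples.

```python
import random


def weighted_seen_sum(sample_weight, n_draws, seed=0):
    """Track the total weight of the distinct samples drawn so far.

    Hypothetical helper for illustration; not the code from the PR.
    """
    rng = random.Random(seed)
    n_samples = len(sample_weight)
    seen = set()
    seen_weight_sum = 0.0
    history = []
    for _ in range(n_draws):
        i = rng.randrange(n_samples)  # uniform draw, as in a stochastic solver
        if i not in seen:
            # First visit: add the sample's weight instead of adding 1.
            seen.add(i)
            seen_weight_sum += sample_weight[i]
        history.append(seen_weight_sum)
    return history


sample_weight = [1.0, 2.0, 0.5, 3.0]
history = weighted_seen_sum(sample_weight, n_draws=1000)
```

With unit weights the counter reduces to the usual number of seen samples, so the unweighted behaviour is unchanged.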

A true_weights argument was added to sag for visualisation purposes (to plot the convergence towards the true minimum). This is useful for the notebook below but can be removed in the final PR. A tol argument was also added so that sag uses the same stopping criterion as _sag_fast.pyx.

cc @snath-xoc

✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: d117189. Link to the linter CI: here

@antoinebaker (Contributor, Author) commented:

See https://gist.github.com/antoinebaker/0fc40c94952a2da371bedb6caca53c10 : now both sag and saga converge towards the true minima, for the weighted and repeated datasets. The weighted version no longer has convergence issues.

We illustrate the convergence on the blob dataset from test_classifier_matching.

We also illustrate with the dataset from our common test check_sample_weight_equivalence, which should pass after a similar fix in _sag_fast.pyx. Note that the sag and saga solvers, while stochastic, actually yield a deterministic output once they have converged (the true minimum), so it is enough to test them with the deterministic repeated/weighted equivalence test.
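
To make the repeated/weighted equivalence concrete, here is a minimal pure-Python sketch. It is not the common test itself: the helper logloss_gradient and the toy data are made up for illustration. It checks that the summed log-loss gradient of a dataset with integer weights matches that of the dataset with rows repeated accordingly. Since the two gradients agree at every coefficient vector, both problems share the same minimum, which is why a converged solver can be compared deterministically.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def logloss_gradient(X, y, coef, sample_weight=None):
    """Sum of per-sample log-loss gradients, for labels y in {0, 1}.

    Hypothetical helper for illustration (no intercept, no penalty).
    """
    if sample_weight is None:
        sample_weight = [1.0] * len(X)
    grad = [0.0] * len(coef)
    for x_i, y_i, s_i in zip(X, y, sample_weight):
        p_i = sigmoid(sum(c * v for c, v in zip(coef, x_i)))
        for j, v in enumerate(x_i):
            grad[j] += s_i * (p_i - y_i) * v
    return grad


X = [[1.0, 2.0], [0.5, -1.0], [-1.5, 0.3]]
y = [1, 0, 1]
sample_weight = [2, 1, 3]

# Repeated dataset: row i appears sample_weight[i] times.
X_rep = [x for x, s in zip(X, sample_weight) for _ in range(s)]
y_rep = [t for t, s in zip(y, sample_weight) for _ in range(s)]

coef = [0.1, -0.2]
grad_weighted = logloss_gradient(X, y, coef, sample_weight)
grad_repeated = logloss_gradient(X_rep, y_rep, coef)
```

The two gradients agree up to floating-point rounding, so any comparison should use a tolerance rather than exact equality.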

@ogrisel (Member) left a comment

Could you please add an xfailing test that checks that both the _sag_fast.pyx implementation and the (now fixed) Python implementation converge to the same solution with sample_weight, both when alpha is small-ish and large-ish?

I am wondering if we shouldn't also adapt the content of https://gist.github.com/antoinebaker/0fc40c94952a2da371bedb6caca53c10 to turn it into a test.

I agree that once the Cython version is fixed, the existing common test can serve this purpose. However, it would only indirectly check the correctness of the reference Python version.

if (max_weight != 0 and max_change / max_weight <= tol) or (
    max_weight == 0 and max_change == 0
):
    print(f"sag convergence after {epoch + 1} epochs")

Member comment:

Rather than printing things, I think the number of iterations before convergence should be reported in the results of this test helper function.

Member comment:

We could also rename n_iter to max_iter now that this function can check for early convergence.

Then we could update the existing tests to check that the effective number of iterations is always strictly lower than max_iter whenever we call it with a strictly positive tol value.
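
A rough sketch of what this could look like, assuming a toy stand-in for the helper: toy_solver below is hypothetical and is not the sag function from test_sag.py. It only reuses the stopping rule quoted above, takes max_iter and tol, and returns the effective number of epochs instead of printing it.

```python
def has_converged(max_change, max_weight, tol):
    """Relative-change stopping rule from the snippet under review."""
    if max_weight != 0:
        return max_change / max_weight <= tol
    return max_change == 0


def toy_solver(target, step=0.5, max_iter=100, tol=1e-6):
    """Hypothetical stand-in for the sag test helper (not the real solver).

    Each epoch moves w a fixed fraction of the way towards `target`.
    Returns (w, n_iter) so that callers can assert on n_iter.
    """
    w = 0.0
    n_iter = 0
    for epoch in range(max_iter):
        new_w = w + step * (target - w)
        max_change = abs(new_w - w)
        w = new_w
        n_iter = epoch + 1
        if has_converged(max_change, abs(w), tol):
            break
    return w, n_iter


w, n_iter = toy_solver(3.0, max_iter=100, tol=1e-6)
# With a strictly positive tol, early stopping should trigger before max_iter.
assert n_iter < 100
```

Returning n_iter alongside the weights keeps the helper silent and lets each existing test add a one-line check on the effective number of iterations.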
