Adapting step_size and sag updates for sample_weight #31837
See https://gist.github.com/antoinebaker/0fc40c94952a2da371bedb6caca53c10: both sag and saga now converge towards the true minima for the weighted and repeated datasets, and the weighted version no longer has convergence issues. The gist illustrates the convergence on the blob dataset and on the dataset from our common test.
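As a minimal sketch of the weighted/repeated equivalence mentioned above: with integer weights, fitting on the weighted dataset should reach the same minimum as fitting on a dataset where each sample is repeated that many times. The toy 1D ridge objective and the helper name `ridge_1d` are assumptions made for illustration. They are not taken from the gist or from scikit-learn.

```python
import math


def ridge_1d(x, y, alpha, sample_weight=None):
    """Closed-form minimiser of sum_i w_i * (theta * x_i - y_i)**2 + alpha * theta**2.

    Hypothetical toy objective, used only to illustrate the equivalence.
    """
    if sample_weight is None:
        sample_weight = [1.0] * len(x)
    numerator = sum(w * xi * yi for w, xi, yi in zip(sample_weight, x, y))
    denominator = sum(w * xi * xi for w, xi in zip(sample_weight, x)) + alpha
    return numerator / denominator


x = [1.0, 2.0, 3.0]
y = [1.5, 3.5, 7.0]
counts = [1, 3, 2]  # integer sample weights

# Repeated dataset: sample i appears counts[i] times, each with unit weight.
x_repeated = [xi for xi, c in zip(x, counts) for _ in range(c)]
y_repeated = [yi for yi, c in zip(y, counts) for _ in range(c)]

theta_weighted = ridge_1d(x, y, alpha=0.1, sample_weight=counts)
theta_repeated = ridge_1d(x_repeated, y_repeated, alpha=0.1)

assert math.isclose(theta_weighted, theta_repeated, rel_tol=1e-12)
```

An iterative solver that handles sample_weight correctly would be expected to approach this same value on both datasets, up to its tolerance.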
Could you please add an xfailing test that checks that both the _sag_fast.pyx implementation and the (now fixed) Python implementation converge to the same solution with sample_weight, both when alpha is small-ish and large-ish?
I am wondering if we shouldn't also adapt the content of https://gist.github.com/antoinebaker/0fc40c94952a2da371bedb6caca53c10 to turn it into a test.
I agree that once the Cython version is fixed, the existing common test can serve this purpose. However, it would only indirectly check the correctness of the reference Python version.
if (max_weight != 0 and max_change / max_weight <= tol) or (
    max_weight == 0 and max_change == 0
):
    print(f"sag convergence after {epoch + 1} epochs")
Rather than printing things, I think the number of iterations before convergence should be reported in the results of this test helper function.
We could also rename n_iter to max_iter now that this function can check for early convergence.
Then we could update the existing tests to check that the effective number of iterations is always strictly lower than max_iter whenever we call it with a strictly positive tol value.
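A rough sketch of what these two suggestions could look like together: the helper returns the effective number of epochs instead of printing it, and a test asserts that it stays strictly below max_iter when tol is strictly positive. The names `run_until_converged` and `halve_distance_to_target` are hypothetical. For simplicity the sketch measures the change once per epoch, which is an assumption and may differ from how the helper in test_sag.py tracks max_change.

```python
def run_until_converged(update, weights, max_iter=1000, tol=1e-3):
    """Apply `update` once per epoch and stop early on convergence.

    Hypothetical helper: returns (weights, n_iter), where n_iter is the
    number of epochs actually run.
    """
    for epoch in range(max_iter):
        previous = list(weights)
        weights = update(weights)
        max_change = max(abs(w - p) for w, p in zip(weights, previous))
        max_weight = max(abs(w) for w in weights)
        # Same form of stopping criterion as in the snippet under review.
        if (max_weight != 0 and max_change / max_weight <= tol) or (
            max_weight == 0 and max_change == 0
        ):
            return weights, epoch + 1
    return weights, max_iter


def halve_distance_to_target(weights, target=(1.0, -2.0)):
    """Toy update rule: move each weight halfway towards a fixed target."""
    return [w + 0.5 * (t - w) for w, t in zip(weights, target)]


max_iter = 100
weights, n_iter = run_until_converged(
    halve_distance_to_target, [0.0, 0.0], max_iter=max_iter, tol=1e-3
)

# With a strictly positive tol, convergence should be detected early.
assert n_iter < max_iter
```

Returning n_iter keeps the helper quiet under pytest and lets each existing test add its own assertion on the iteration count.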
Reference Issues/PRs
Helper for PR #31675 to fix sample_weight support in SAG(A) solvers #31536.
What does this implement/fix? Explain your changes.
In test_sag.py several fixes are made to properly handle sample_weight.
- get_step_size: the formula for the Lipschitz smoothness constant is corrected to take sample_weight into account
- sag: the updates for the weights and intercept are corrected
- sag: the number of seen elements (which converges to n_samples) is replaced by the weighted sum of seen elements (which converges to sample_weight.sum())

A true_weights argument was added to sag for visualisation purposes (plotting the convergence towards the true minima). This is useful for the notebook below but can be removed in the final PR. A tol argument was also added to use the same stopping criterion as _sag_fast.pyx.

cc @snath-xoc
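To illustrate the third fix listed above, here is a small sketch of how a weighted sum of seen elements behaves under random sampling. The function name `weighted_seen_sum` and the sampling loop are assumptions for illustration only. This is not the code in test_sag.py.

```python
import random


def weighted_seen_sum(sample_weight, n_draws, seed=0):
    """Draw sample indices at random and track which ones have been seen.

    Hypothetical helper: returns (n_seen, seen_weight), where n_seen counts
    distinct samples drawn so far and seen_weight sums their weights.
    """
    rng = random.Random(seed)
    seen = set()
    seen_weight = 0.0
    for _ in range(n_draws):
        i = rng.randrange(len(sample_weight))
        if i not in seen:
            # A sample contributes its weight only the first time it is drawn.
            seen.add(i)
            seen_weight += sample_weight[i]
    return len(seen), seen_weight


sample_weight = [1.0, 2.0, 3.0, 0.5]
n_seen, seen_weight = weighted_seen_sum(sample_weight, n_draws=200)

# Once every sample has been drawn at least once, the unweighted count
# equals n_samples and the weighted sum equals sample_weight.sum().
assert n_seen == len(sample_weight)
assert seen_weight == sum(sample_weight)
```

With integer weights, sum(sample_weight) is also the number of rows in the equivalent repeated dataset, which is why using the weighted sum keeps the weighted and repeated runs on the same scale.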