Skip to content

Fix ClassifierChain error message for multiclass-multioutput targets #31797

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

JawadAliAI
Copy link

  • Add validation in ClassifierChain.fit() to detect multiclass-multioutput targets
  • Raise clear ValueError explaining that ClassifierChain is for multilabel classification
  • Suggest using MultiOutputClassifier as alternative
  • Add comprehensive tests for both error case and normal multilabel operation

Fixes #13339

Reference Issues/PRs

What does this implement/fix? Explain your changes.

Any other comments?

- Add validation in ClassifierChain.fit() to detect multiclass-multioutput targets
- Raise clear ValueError explaining that ClassifierChain is for multilabel classification
- Suggest using MultiOutputClassifier as alternative
- Add comprehensive tests for both error case and normal multilabel operation

Fixes scikit-learn#13339
Copy link

❌ Linting issues

This PR is introducing linting issues. Here's a summary of the issues. Note that you can avoid having linting issues by enabling pre-commit hooks. Instructions to enable them can be found here.

You can see the details of the linting issues under the lint job here


ruff check

ruff detected issues. Please run ruff check --fix --output-format=full locally, fix the remaining issues, and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.11.7.


sklearn/multioutput.py:1091:1: W293 [*] Blank line contains whitespace
     |
1089 |         # Validate input data
1090 |         X, Y = validate_data(self, X, Y, multi_output=True, accept_sparse=True)
1091 |         
     | ^^^^^^^^ W293
1092 |         # Check if we have multiclass-multioutput targets, which are not supported
1093 |         target_type = type_of_target(Y)
     |
     = help: Remove whitespace from blank line

sklearn/tests/test_multioutput.py:884:89: E501 Line too long (90 > 88)
    |
883 | def test_classifier_chain_multiclass_multioutput_error():
884 |     """Test that ClassifierChain raises clear error for multiclass-multioutput targets."""
    |                                                                                         ^^ E501
885 |     # Create multiclass-multioutput data (3 classes per output)
886 |     X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
    |

sklearn/tests/test_multioutput.py:889:65: W291 [*] Trailing whitespace
    |
887 |     y_multiclass_multioutput = np.array([
888 |         [0, 1],  # First output: class 0, Second output: class 1
889 |         [1, 2],  # First output: class 1, Second output: class 2  
    |                                                                 ^^ W291
890 |         [2, 0],  # First output: class 2, Second output: class 0
891 |         [0, 2],  # First output: class 0, Second output: class 2
    |
    = help: Remove trailing whitespace

sklearn/tests/test_multioutput.py:895:1: W293 [*] Blank line contains whitespace
    |
893 |         [2, 0],  # First output: class 2, Second output: class 0
894 |     ])
895 |     
    | ^^^^ W293
896 |     # This should raise a ValueError with clear message
897 |     chain = ClassifierChain(LogisticRegression(random_state=42))
    |
    = help: Remove whitespace from blank line

sklearn/tests/test_multioutput.py:898:1: W293 [*] Blank line contains whitespace
    |
896 |     # This should raise a ValueError with clear message
897 |     chain = ClassifierChain(LogisticRegression(random_state=42))
898 |     
    | ^^^^ W293
899 |     expected_msg = (
900 |         "ClassifierChain does not support multiclass-multioutput targets. "
    |
    = help: Remove whitespace from blank line

sklearn/tests/test_multioutput.py:905:1: W293 [*] Blank line contains whitespace
    |
903 |         "per output. Consider using MultiOutputClassifier instead."
904 |     )
905 |     
    | ^^^^ W293
906 |     with pytest.raises(ValueError, match=expected_msg):
907 |         chain.fit(X, y_multiclass_multioutput)
    |
    = help: Remove whitespace from blank line

sklearn/tests/test_multioutput.py:920:1: W293 [*] Blank line contains whitespace
    |
918 |         [1, 0],  # Has label 1, not label 2
919 |     ])
920 |     
    | ^^^^ W293
921 |     # This should work fine (no error)
922 |     chain = ClassifierChain(LogisticRegression(random_state=42))
    |
    = help: Remove whitespace from blank line

sklearn/tests/test_multioutput.py:924:1: W293 [*] Blank line contains whitespace
    |
922 |     chain = ClassifierChain(LogisticRegression(random_state=42))
923 |     chain.fit(X, y_multilabel)
924 |     
    | ^^^^ W293
925 |     # Basic functionality check
926 |     predictions = chain.predict(X)
    |
    = help: Remove whitespace from blank line

Found 8 errors.
[*] 7 fixable with the `--fix` option.

ruff format

ruff detected issues. Please run ruff format locally and push the changes. Here you can see the detected issues. Note that the installed ruff version is ruff=0.11.7.


--- sklearn/multioutput.py
+++ sklearn/multioutput.py
@@ -1088,7 +1088,7 @@
 
         # Validate input data
         X, Y = validate_data(self, X, Y, multi_output=True, accept_sparse=True)
-        
+
         # Check if we have multiclass-multioutput targets, which are not supported
         target_type = type_of_target(Y)
         if target_type == "multiclass-multioutput":

--- sklearn/tests/test_multioutput.py
+++ sklearn/tests/test_multioutput.py
@@ -884,25 +884,27 @@
     """Test that ClassifierChain raises clear error for multiclass-multioutput targets."""
     # Create multiclass-multioutput data (3 classes per output)
     X = np.array([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]])
-    y_multiclass_multioutput = np.array([
-        [0, 1],  # First output: class 0, Second output: class 1
-        [1, 2],  # First output: class 1, Second output: class 2  
-        [2, 0],  # First output: class 2, Second output: class 0
-        [0, 2],  # First output: class 0, Second output: class 2
-        [1, 1],  # First output: class 1, Second output: class 1
-        [2, 0],  # First output: class 2, Second output: class 0
-    ])
-    
+    y_multiclass_multioutput = np.array(
+        [
+            [0, 1],  # First output: class 0, Second output: class 1
+            [1, 2],  # First output: class 1, Second output: class 2
+            [2, 0],  # First output: class 2, Second output: class 0
+            [0, 2],  # First output: class 0, Second output: class 2
+            [1, 1],  # First output: class 1, Second output: class 1
+            [2, 0],  # First output: class 2, Second output: class 0
+        ]
+    )
+
     # This should raise a ValueError with clear message
     chain = ClassifierChain(LogisticRegression(random_state=42))
-    
+
     expected_msg = (
         "ClassifierChain does not support multiclass-multioutput targets. "
         "ClassifierChain is designed for multilabel classification where "
         "each target is binary \\(0 or 1\\). Your target has multiple classes "
         "per output. Consider using MultiOutputClassifier instead."
     )
-    
+
     with pytest.raises(ValueError, match=expected_msg):
         chain.fit(X, y_multiclass_multioutput)
 
@@ -911,17 +913,19 @@
     """Test that ClassifierChain still works correctly with multilabel data."""
     # Create proper multilabel data (binary values only)
     X = np.array([[1, 2], [3, 4], [5, 6], [7, 8]])
-    y_multilabel = np.array([
-        [0, 1],  # Not label 1, has label 2
-        [1, 1],  # Has label 1, has label 2
-        [0, 0],  # No labels
-        [1, 0],  # Has label 1, not label 2
-    ])
-    
+    y_multilabel = np.array(
+        [
+            [0, 1],  # Not label 1, has label 2
+            [1, 1],  # Has label 1, has label 2
+            [0, 0],  # No labels
+            [1, 0],  # Has label 1, not label 2
+        ]
+    )
+
     # This should work fine (no error)
     chain = ClassifierChain(LogisticRegression(random_state=42))
     chain.fit(X, y_multilabel)
-    
+
     # Basic functionality check
     predictions = chain.predict(X)
     assert predictions.shape == y_multilabel.shape

2 files would be reformatted, 924 files already formatted

Generated for commit: cc565d3. Link to the linter CI: here

@betatim
Copy link
Member

betatim commented Jul 22, 2025

Thanks for creating the PR. It looks like you used some tooling to create it, not a bad thing by itself. However standard clean up comments apply, please clean up the diff a bit (e.g. comments that just restate what is happening on the next line), shorten the exception message and reformat your top PR comment to use the template. In particular explain why you made the changes you made (it is easy to see what you changed, it is less easy to see why you think these are the right changes to make or what tradeoffs you considered)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Bad error messages in ClassifierChain on multioutput multiclass
2 participants