Skip to content

Commit f15a6de

Browse files
🐛 Fix nested dataclass comparison (#120)
Fix how `IsDataclass` handles nested dataclasses so they can be compared successfully. Previously something like the following would not work: ```python from dataclasses import dataclass from dirty_equals import IsDataclass @DataClass class Address: street: str zip_code: str @DataClass class Person: name: str address: Address person = Person( name='Alice', address=Address(street='123 Main St', zip_code='12345'), ) assert person == IsDataclass( name='Alice', address=IsDataclass(street='123 Main St', zip_code='12345') ) ``` This is a result of `IsDataclass` converting dataclasses to a dictionary for comparison using `dataclasses.asdict`, which _recursively_ converts dataclasses to plain dictionaries. Thus the inner `IsDataclass` fails, because the other side of the comparison is now a dictionary, not a dataclass. These changes fix this by shallowly converting the dataclass to a dictionary using [the approach recommended in the dataclasses documentation](https://docs.python.org/3/library/dataclasses.html#dataclasses.asdict): ```python {field.name: getattr(obj, field.name) for field in fields(obj)} ``` These changes also include a test for nested dataclass comparisons to ensure this works as expected and continues to work in the future.
1 parent 32f23d1 commit f15a6de

File tree

2 files changed

+24
-4
lines changed

2 files changed

+24
-4
lines changed

dirty_equals/_other.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import json
44
import re
5-
from dataclasses import asdict, is_dataclass
5+
from dataclasses import fields, is_dataclass
66
from enum import Enum
77
from functools import lru_cache
88
from ipaddress import IPv4Address, IPv4Network, IPv6Address, IPv6Network, ip_network
@@ -486,10 +486,11 @@ def _fields_check(self, other: Any) -> bool:
486486
"""
487487
Checks exactness of fields using [`IsDict`][dirty_equals.IsDict] with given settings.
488488
489-
Remark that if this method is called, then `other` is an instance of a dataclass, therefore we can call
490-
`dataclasses.asdict` to convert to a dict.
489+
Remark that if this method is called, then `other` is an instance of a dataclass, therefore we can use
490+
`dataclasses.fields` to get its fields. We use a shallow conversion to preserve nested dataclass instances.
491491
"""
492-
return asdict(other) == IsDict(self._repr_kwargs).settings(strict=self.strict, partial=self.partial)
492+
other_dict = {field.name: getattr(other, field.name) for field in fields(other)}
493+
return other_dict == IsDict(self._repr_kwargs).settings(strict=self.strict, partial=self.partial)
493494

494495

495496
class IsPartialDataclass(IsDataclass):

tests/test_other.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,21 @@ class Foo:
3939
foo = Foo(1, 2, 'c')
4040

4141

42+
@dataclass
43+
class Address:
44+
street: str
45+
zip_code: str
46+
47+
48+
@dataclass
49+
class Person:
50+
name: str
51+
address: Address
52+
53+
54+
person = Person(name='Alice', address=Address(street='123 Main St', zip_code='12345'))
55+
56+
4257
@pytest.mark.parametrize(
4358
'other,dirty',
4459
[
@@ -373,6 +388,10 @@ def test_is_dataclass_false(other, dirty):
373388
assert other != dirty
374389

375390

391+
def test_is_dataclass_nested():
392+
assert person == IsDataclass(name='Alice', address=IsDataclass(street='123 Main St', zip_code='12345'))
393+
394+
376395
@pytest.mark.parametrize(
377396
'other,dirty',
378397
[

0 commit comments

Comments
 (0)