Skip to content

Commit 22f9ead

Browse files
alexfiklinducer
authored andcommitted
feat: improve dataclass container
1 parent dee0ca4 commit 22f9ead

File tree

2 files changed

+53
-10
lines changed

2 files changed

+53
-10
lines changed

arraycontext/container/dataclass.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ def dataclass_array_container(cls: type) -> type:
5959
array containers, even if they wrap one.
6060
"""
6161

62+
from types import GenericAlias, UnionType
63+
6264
assert is_dataclass(cls)
6365

6466
def is_array_field(f: Field) -> bool:
@@ -75,7 +77,8 @@ def is_array_field(f: Field) -> bool:
7577
# This is not set in stone, but mostly driven by current usage!
7678

7779
origin = get_origin(f.type)
78-
if origin is Union:
80+
# NOTE: `UnionType` is returned when using `Type1 | Type2`
81+
if origin in (Union, UnionType):
7982
if all(is_array_type(arg) for arg in get_args(f.type)):
8083
return True
8184
else:
@@ -94,13 +97,14 @@ def is_array_field(f: Field) -> bool:
9497
f"Field with 'init=False' not allowed: '{f.name}'")
9598

9699
# NOTE:
100+
# * `GenericAlias` catches typed `list`, `tuple`, etc.
97101
# * `_BaseGenericAlias` catches `List`, `Tuple`, etc.
98102
# * `_SpecialForm` catches `Any`, `Literal`, etc.
99103
from typing import ( # type: ignore[attr-defined]
100104
_BaseGenericAlias,
101105
_SpecialForm,
102106
)
103-
if isinstance(f.type, _BaseGenericAlias | _SpecialForm):
107+
if isinstance(f.type, GenericAlias | _BaseGenericAlias | _SpecialForm):
104108
# NOTE: anything except a Union is not allowed
105109
raise TypeError(
106110
f"Typing annotation not supported on field '{f.name}': "

test/test_utils.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,9 @@ def test_pt_actx_key_stringification_uniqueness():
4949

5050
def test_dataclass_array_container() -> None:
5151
from dataclasses import dataclass, field
52-
from typing import Optional
52+
from typing import Optional, Tuple # noqa: UP035
5353

54-
from arraycontext import dataclass_array_container
54+
from arraycontext import Array, dataclass_array_container
5555

5656
# {{{ string fields
5757

@@ -60,7 +60,7 @@ class ArrayContainerWithStringTypes:
6060
x: np.ndarray
6161
y: "np.ndarray"
6262

63-
with pytest.raises(TypeError):
63+
with pytest.raises(TypeError, match="String annotation on field 'y'"):
6464
# NOTE: cannot have string annotations in container
6565
dataclass_array_container(ArrayContainerWithStringTypes)
6666

@@ -73,12 +73,32 @@ class ArrayContainerWithOptional:
7373
x: np.ndarray
7474
y: Optional[np.ndarray]
7575

76-
with pytest.raises(TypeError):
76+
with pytest.raises(TypeError, match="Field 'y' union contains non-array"):
7777
# NOTE: cannot have wrapped annotations (here by `Optional`)
7878
dataclass_array_container(ArrayContainerWithOptional)
7979

8080
# }}}
8181

82+
# {{{ type annotations
83+
84+
@dataclass
85+
class ArrayContainerWithTuple:
86+
x: Array
87+
y: Tuple[Array, Array]
88+
89+
with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"):
90+
dataclass_array_container(ArrayContainerWithTuple)
91+
92+
@dataclass
93+
class ArrayContainerWithTupleAlt:
94+
x: Array
95+
y: tuple[Array, Array]
96+
97+
with pytest.raises(TypeError, match="Typing annotation not supported on field 'y'"):
98+
dataclass_array_container(ArrayContainerWithTupleAlt)
99+
100+
# }}}
101+
82102
# {{{ field(init=False)
83103

84104
@dataclass
@@ -87,16 +107,14 @@ class ArrayContainerWithInitFalse:
87107
y: np.ndarray = field(default_factory=lambda: np.zeros(42),
88108
init=False, repr=False)
89109

90-
with pytest.raises(ValueError):
110+
with pytest.raises(ValueError, match="Field with 'init=False' not allowed"):
91111
# NOTE: init=False fields are not allowed
92112
dataclass_array_container(ArrayContainerWithInitFalse)
93113

94114
# }}}
95115

96116
# {{{ device arrays
97117

98-
from arraycontext import Array
99-
100118
@dataclass
101119
class ArrayContainerWithArray:
102120
x: Array
@@ -126,6 +144,13 @@ class ArrayContainerWithUnion:
126144

127145
dataclass_array_container(ArrayContainerWithUnion)
128146

147+
@dataclass
148+
class ArrayContainerWithUnionAlt:
149+
x: np.ndarray
150+
y: np.ndarray | Array
151+
152+
dataclass_array_container(ArrayContainerWithUnionAlt)
153+
129154
# }}}
130155

131156
# {{{ non-container union
@@ -135,12 +160,26 @@ class ArrayContainerWithWrongUnion:
135160
x: np.ndarray
136161
y: Union[np.ndarray, float]
137162

138-
with pytest.raises(TypeError):
163+
with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
139164
# NOTE: float is not an ArrayContainer, so y should fail
140165
dataclass_array_container(ArrayContainerWithWrongUnion)
141166

142167
# }}}
143168

169+
# {{{ optional union
170+
171+
@dataclass
172+
class ArrayContainerWithOptionalUnion:
173+
x: np.ndarray
174+
y: np.ndarray | None
175+
176+
with pytest.raises(TypeError, match="Field 'y' union contains non-array container"):
177+
# NOTE: None is not an ArrayContainer, so y should fail
178+
dataclass_array_container(ArrayContainerWithWrongUnion)
179+
180+
# }}}
181+
182+
144183
# }}}
145184

146185

0 commit comments

Comments
 (0)