26
26
import jax
27
27
import numpy as np
28
28
29
+ _AXIS_TYPE_MAP = {str (val ): val for val in jax .sharding .AxisType }
29
30
PartitionSpecElement = Union [None , str , Tuple [str , ...]]
30
31
31
32
_PARTITION_SPEC = 'partition_spec'
32
33
_SHARDING = '_sharding'
33
34
_SHARDING_TYPE = 'sharding_type'
34
35
_DEVICE_STR = 'device_str'
35
36
_MESH_AXES = 'axis_names'
37
+ _MESH_AXIS_TYPES = 'axis_types'
36
38
_MESH_SHAPE = 'shape'
37
39
_DEVICES_SHAPE = 'shape'
38
40
_DEVICE_MESH = 'device_mesh'
@@ -181,6 +183,7 @@ class NamedShardingMetadata(ShardingMetadata):
181
183
partition_spec : Tuple [
182
184
PartitionSpecElement , ...
183
185
] # Each element is either ``None``, a string, or a tuple of strings.
186
+ axis_types : Optional [Tuple [jax .sharding .AxisType , ...]] = None
184
187
185
188
# Optional device mesh. If it's None, use jax.devices(),
186
189
# otherwise, the stored device_mesh will be used to recreate NamedSharding.
@@ -193,6 +196,7 @@ def from_jax_sharding(
193
196
return cls (
194
197
shape = np .array (list (jax_sharding .mesh .shape .values ())),
195
198
axis_names = list (jax_sharding .mesh .axis_names ),
199
+ axis_types = tuple (jax_sharding .mesh .axis_types ),
196
200
partition_spec = tuple (jax_sharding .spec ),
197
201
device_mesh = DeviceMetadataMesh .from_jax_mesh (jax_sharding .mesh ),
198
202
)
@@ -207,6 +211,7 @@ def to_jax_sharding(self) -> jax.sharding.NamedSharding:
207
211
jax .sharding .Mesh (
208
212
np .asarray (mesh_devices ).reshape (self .shape ),
209
213
axis_names = self .axis_names ,
214
+ axis_types = self .axis_types ,
210
215
),
211
216
spec = jax .sharding .PartitionSpec (* self .partition_spec ),
212
217
)
@@ -222,6 +227,9 @@ def from_deserialized_dict(
222
227
):
223
228
shape = np .array (deserialized_dict [_MESH_SHAPE ])
224
229
axis_names = list (deserialized_dict [_MESH_AXES ])
230
+ axis_types = None
231
+ if axis_types_raw := deserialized_dict .get (_MESH_AXIS_TYPES ):
232
+ axis_types = tuple ([_AXIS_TYPE_MAP [s ] for s in axis_types_raw ])
225
233
partition_spec = tuple (deserialized_dict [_PARTITION_SPEC ])
226
234
if device_mesh_dic := deserialized_dict .get (_DEVICE_MESH ):
227
235
device_mesh = DeviceMetadataMesh .from_dict (device_mesh_dic )
@@ -231,6 +239,7 @@ def from_deserialized_dict(
231
239
return cls (
232
240
shape = shape ,
233
241
axis_names = axis_names ,
242
+ axis_types = axis_types ,
234
243
partition_spec = partition_spec ,
235
244
device_mesh = device_mesh ,
236
245
)
@@ -244,6 +253,8 @@ def to_serialized_string(self) -> str:
244
253
sharding_data [_SHARDING_TYPE ] = ShardingTypes .NAMED_SHARDING .value
245
254
sharding_data [_MESH_SHAPE ] = self .shape .tolist ()
246
255
sharding_data [_MESH_AXES ] = self .axis_names
256
+ if self .axis_types is not None :
257
+ sharding_data [_MESH_AXIS_TYPES ] = [str (a ) for a in self .axis_types ]
247
258
sharding_data [_PARTITION_SPEC ] = self .partition_spec
248
259
if self .device_mesh :
249
260
sharding_data [_DEVICE_MESH ] = dataclasses .asdict (self .device_mesh )
@@ -252,14 +263,15 @@ def to_serialized_string(self) -> str:
252
263
def __repr__ (self ):
253
264
return (
254
265
f'NamedShardingMetadata(shape={ self .shape } ,'
255
- f' axis_names={ self .axis_names } , partition_spec ={ self .partition_spec } ) '
256
- f' device_mesh={ self .device_mesh } '
266
+ f' axis_names={ self .axis_names } , axis_types ={ self .axis_types } , '
267
+ f' partition_spec= { self . partition_spec } ) device_mesh={ self .device_mesh } '
257
268
)
258
269
259
270
def __eq__ (self , other ):
260
271
return (
261
272
np .array_equal (self .shape , other .shape )
262
273
and self .axis_names == other .axis_names
274
+ and self .axis_types == other .axis_types
263
275
and self .partition_spec == other .partition_spec
264
276
and self .device_mesh == other .device_mesh
265
277
)
0 commit comments