Skip to content

Commit 29b3525

Browse files
author
Orbax Authors
committed
Add support for axis_types in NamedShardingMetadata.
PiperOrigin-RevId: 811194530
1 parent 964add5 commit 29b3525

File tree

4 files changed

+18
-3
lines changed

checkpoint/orbax/checkpoint/_src/handlers/pytree_checkpoint_handler_test_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -488,6 +488,7 @@ def test_sharding_variable_devices(self):
488488
shape=np.array([2]),
489489
axis_names=['x'],
490490
partition_spec=('x',),
491+
axis_types=(jax.sharding.AxisType.Auto,),
491492
device_mesh=sharding_metadata.DeviceMetadataMesh.from_jax_mesh(
492493
jax.sharding.Mesh(devices_subset, ('x',))
493494
),
@@ -496,6 +497,7 @@ def test_sharding_variable_devices(self):
496497
shape=np.array([8]),
497498
axis_names=['x'],
498499
partition_spec=('x',),
500+
axis_types=(jax.sharding.AxisType.Auto,),
499501
device_mesh=sharding_metadata.DeviceMetadataMesh.from_jax_mesh(
500502
jax.sharding.Mesh(jax.devices(), ('x',))
501503
),

checkpoint/orbax/checkpoint/_src/metadata/sharding.py

Lines changed: 14 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -26,13 +26,15 @@
2626
import jax
2727
import numpy as np
2828

29+
_AXIS_TYPE_MAP = {str(val): val for val in jax.sharding.AxisType}
2930
PartitionSpecElement = Union[None, str, Tuple[str, ...]]
3031

3132
_PARTITION_SPEC = 'partition_spec'
3233
_SHARDING = '_sharding'
3334
_SHARDING_TYPE = 'sharding_type'
3435
_DEVICE_STR = 'device_str'
3536
_MESH_AXES = 'axis_names'
37+
_MESH_AXIS_TYPES = 'axis_types'
3638
_MESH_SHAPE = 'shape'
3739
_DEVICES_SHAPE = 'shape'
3840
_DEVICE_MESH = 'device_mesh'
@@ -181,6 +183,7 @@ class NamedShardingMetadata(ShardingMetadata):
181183
partition_spec: Tuple[
182184
PartitionSpecElement, ...
183185
] # Each element is either ``None``, a string, or a tuple of strings.
186+
axis_types: Optional[Tuple[jax.sharding.AxisType, ...]] = None
184187

185188
# Optional device mesh. If it's None, use jax.devices(),
186189
# otherwise, the stored device_mesh will be used to recreate NamedSharding.
@@ -193,6 +196,7 @@ def from_jax_sharding(
193196
return cls(
194197
shape=np.array(list(jax_sharding.mesh.shape.values())),
195198
axis_names=list(jax_sharding.mesh.axis_names),
199+
axis_types=tuple(jax_sharding.mesh.axis_types),
196200
partition_spec=tuple(jax_sharding.spec),
197201
device_mesh=DeviceMetadataMesh.from_jax_mesh(jax_sharding.mesh),
198202
)
@@ -207,6 +211,7 @@ def to_jax_sharding(self) -> jax.sharding.NamedSharding:
207211
jax.sharding.Mesh(
208212
np.asarray(mesh_devices).reshape(self.shape),
209213
axis_names=self.axis_names,
214+
axis_types=self.axis_types,
210215
),
211216
spec=jax.sharding.PartitionSpec(*self.partition_spec),
212217
)
@@ -222,6 +227,9 @@ def from_deserialized_dict(
222227
):
223228
shape = np.array(deserialized_dict[_MESH_SHAPE])
224229
axis_names = list(deserialized_dict[_MESH_AXES])
230+
axis_types = None
231+
if axis_types_raw := deserialized_dict.get(_MESH_AXIS_TYPES):
232+
axis_types = tuple([_AXIS_TYPE_MAP[s] for s in axis_types_raw])
225233
partition_spec = tuple(deserialized_dict[_PARTITION_SPEC])
226234
if device_mesh_dic := deserialized_dict.get(_DEVICE_MESH):
227235
device_mesh = DeviceMetadataMesh.from_dict(device_mesh_dic)
@@ -231,6 +239,7 @@ def from_deserialized_dict(
231239
return cls(
232240
shape=shape,
233241
axis_names=axis_names,
242+
axis_types=axis_types,
234243
partition_spec=partition_spec,
235244
device_mesh=device_mesh,
236245
)
@@ -244,6 +253,8 @@ def to_serialized_string(self) -> str:
244253
sharding_data[_SHARDING_TYPE] = ShardingTypes.NAMED_SHARDING.value
245254
sharding_data[_MESH_SHAPE] = self.shape.tolist()
246255
sharding_data[_MESH_AXES] = self.axis_names
256+
if self.axis_types is not None:
257+
sharding_data[_MESH_AXIS_TYPES] = [str(a) for a in self.axis_types]
247258
sharding_data[_PARTITION_SPEC] = self.partition_spec
248259
if self.device_mesh:
249260
sharding_data[_DEVICE_MESH] = dataclasses.asdict(self.device_mesh)
@@ -252,14 +263,15 @@ def to_serialized_string(self) -> str:
252263
def __repr__(self):
253264
return (
254265
f'NamedShardingMetadata(shape={self.shape},'
255-
f' axis_names={self.axis_names}, partition_spec={self.partition_spec})'
256-
f' device_mesh={self.device_mesh}'
266+
f' axis_names={self.axis_names}, axis_types={self.axis_types},'
267+
f' partition_spec={self.partition_spec}) device_mesh={self.device_mesh}'
257268
)
258269

259270
def __eq__(self, other):
260271
return (
261272
np.array_equal(self.shape, other.shape)
262273
and self.axis_names == other.axis_names
274+
and self.axis_types == other.axis_types
263275
and self.partition_spec == other.partition_spec
264276
and self.device_mesh == other.device_mesh
265277
)

checkpoint/orbax/checkpoint/_src/metadata/sharding_test.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ def test_convert_between_jax_named_sharding_and_sharding_metadata(self):
3030
shape=np.array([1]),
3131
axis_names=(["x"]),
3232
partition_spec=(None,),
33+
axis_types=(jax.sharding.AxisType.Auto,),
3334
device_mesh=sharding_metadata.DeviceMetadataMesh.from_jax_mesh(
3435
jax_sharding.mesh
3536
),

checkpoint/orbax/checkpoint/_src/serialization/serialization.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,7 @@ async def read_and_create_array(
508508
collections.defaultdict(list)
509509
)
510510
for d, idx in _get_device_to_index_map(global_shape, sharding).items():
511-
if d in sharding._addressable_device_assignment: # pylint: disable=protected-access
511+
if d in jax.local_devices(): # pylint: disable=protected-access
512512
local_indices_devices_map[
513513
np_utils.to_hashable_index(idx, shape=global_shape)
514514
].append(d)

0 commit comments

Comments
 (0)