Skip to content

Commit 6620716

Browse files
cpgaffney1Orbax Authors
authored andcommitted
Internal change.
PiperOrigin-RevId: 808746490
1 parent dc6f085 commit 6620716

File tree

2 files changed

+239
-2
lines changed

2 files changed

+239
-2
lines changed

checkpoint/orbax/checkpoint/test_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,9 +231,9 @@ def assert_array_equal(testclass, v_expected, v_actual):
231231
for shard_expected, shard_actual in zip(
232232
v_expected.addressable_shards, v_actual.addressable_shards
233233
):
234-
np.testing.assert_array_equal(shard_expected.data, shard_actual.data)
234+
np.testing.assert_array_equal(shard_actual.data, shard_expected.data)
235235
elif isinstance(v_expected, (np.ndarray, jnp.ndarray)):
236-
np.testing.assert_array_equal(v_expected, v_actual)
236+
np.testing.assert_array_equal(v_actual, v_expected)
237237
else:
238238
testclass.assertEqual(v_expected, v_actual)
239239

Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Test class for process-local paths."""
16+
17+
from __future__ import annotations
18+
19+
import typing
20+
from typing import Iterator, Protocol
21+
22+
from etils import epath
23+
from orbax.checkpoint._src.multihost import multihost
24+
25+
26+
@typing.runtime_checkable
27+
class LocalPath(Protocol):
28+
"""A Path implementation for testing process-local paths.
29+
30+
In the future, this class may more completely provide all functions and
31+
properties of a pathlib Path, but for now, it only provides the minimum
32+
needed to support relevant tests.
33+
34+
This class is intended to receive a base path, and append a process-specific
35+
suffix to it when path operations are performed (the appending should be
36+
delayed as much as possible).
37+
38+
Operations that combine two LocalPaths should not re-append the suffix.
39+
40+
One may ask - why not just construct a path initially with some
41+
process-specific suffix appended, and use that for testing? This works for
42+
multi-controller, but not for single-controller, where paths are typically
43+
constructed in the
44+
controller (single-process) and passed to workers (multi-process). The
45+
process index must be appended when path operations are performed.
46+
"""
47+
48+
@property
49+
def path(self) -> epath.Path:
50+
...
51+
52+
def exists(self) -> bool:
53+
"""Returns True if self exists."""
54+
...
55+
56+
def is_dir(self) -> bool:
57+
"""Returns True if self is a dir."""
58+
...
59+
60+
def is_file(self) -> bool:
61+
"""Returns True if self is a file."""
62+
...
63+
64+
def iterdir(self) -> Iterator[LocalPath]:
65+
"""Iterates over the directory."""
66+
...
67+
68+
def glob(self, pattern: str) -> Iterator[LocalPath]:
69+
"""Yields all matching files (of any kind)."""
70+
...
71+
72+
def read_bytes(self) -> bytes:
73+
"""Reads contents of self as bytes."""
74+
...
75+
76+
def read_text(self, encoding: str | None = None) -> str:
77+
"""Reads contents of self as a string."""
78+
...
79+
80+
def mkdir(
81+
self,
82+
mode: int | None = None,
83+
parents: bool = False,
84+
exist_ok: bool = False,
85+
) -> None:
86+
"""Create a new directory at this given path."""
87+
...
88+
89+
def rmdir(self) -> None:
90+
"""Remove the empty directory at this given path."""
91+
...
92+
93+
def rmtree(self, missing_ok: bool = False) -> None:
94+
"""Remove the directory, including all sub-files."""
95+
...
96+
97+
def unlink(self, missing_ok: bool = False) -> None:
98+
"""Remove this file or symbolic link."""
99+
...
100+
101+
def write_bytes(self, data: bytes) -> int:
102+
"""Writes content as bytes."""
103+
...
104+
105+
def write_text(
106+
self,
107+
data: str,
108+
encoding: str | None = None,
109+
errors: str | None = None,
110+
) -> int:
111+
"""Writes content as str."""
112+
...
113+
114+
### PosixPath methods ###
115+
116+
def as_posix(self) -> str:
117+
...
118+
119+
def __truediv__(self, key: epath.PathLike) -> epath.Path:
120+
...
121+
122+
@property
123+
def name(self) -> str:
124+
...
125+
126+
@property
127+
def parent(self) -> epath.Path:
128+
...
129+
130+
131+
LocalPathLike = LocalPath | str
132+
133+
134+
def _resolve_local_path_like(path: LocalPathLike) -> epath.Path:
135+
if isinstance(path, LocalPath):
136+
return typing.cast(LocalPath, path).path
137+
return epath.Path(path)
138+
139+
140+
class _LocalPathImpl(LocalPath):
141+
"""etils.epath.Path implementation for multiprocess tests."""
142+
143+
def __init__(self, path: epath.PathLike):
144+
self._path = epath.Path(path)
145+
146+
def __repr__(self) -> str:
147+
return f'{self.__class__.__name__}({self.path})'
148+
149+
@property
150+
def base_path(self) -> epath.Path:
151+
return self._path
152+
153+
@property
154+
def path(self) -> epath.Path:
155+
return self.base_path / str(f'local_{multihost.process_index()}')
156+
157+
def exists(self) -> bool:
158+
"""Returns True if self exists."""
159+
return self.path.exists()
160+
161+
def is_dir(self) -> bool:
162+
"""Returns True if self is a dir."""
163+
return self.path.is_dir()
164+
165+
def is_file(self) -> bool:
166+
"""Returns True if self is a file."""
167+
return self.path.is_file()
168+
169+
def iterdir(self) -> Iterator[LocalPath]:
170+
"""Iterates over the directory."""
171+
return (_LocalPathImpl(p) for p in self.path.iterdir())
172+
173+
def glob(self, pattern: str) -> Iterator[LocalPath]:
174+
"""Yields all matching files (of any kind)."""
175+
return (_LocalPathImpl(p) for p in self.path.glob(pattern))
176+
177+
def read_bytes(self) -> bytes:
178+
"""Reads contents of self as bytes."""
179+
return self.path.read_bytes()
180+
181+
def read_text(self, encoding: str | None = None) -> str:
182+
"""Reads contents of self as a string."""
183+
return self.path.read_text(encoding=encoding)
184+
185+
def mkdir(
186+
self,
187+
mode: int | None = None,
188+
parents: bool = False,
189+
exist_ok: bool = False,
190+
) -> None:
191+
"""Create a new directory at this given path."""
192+
del mode # mode is not supported by epath.Path.mkdir
193+
self.path.mkdir(parents=parents, exist_ok=exist_ok)
194+
195+
def rmdir(self) -> None:
196+
"""Remove the empty directory at this given path."""
197+
self.path.rmdir()
198+
199+
def rmtree(self, missing_ok: bool = False) -> None:
200+
"""Remove the directory, including all sub-files."""
201+
self.path.rmtree(missing_ok=missing_ok)
202+
203+
def unlink(self, missing_ok: bool = False) -> None:
204+
"""Remove this file or symbolic link."""
205+
self.path.unlink(missing_ok=missing_ok)
206+
207+
def write_bytes(self, data: bytes) -> int:
208+
"""Writes content as bytes."""
209+
return self.path.write_bytes(data)
210+
211+
def write_text(
212+
self,
213+
data: str,
214+
encoding: str | None = None,
215+
errors: str | None = None,
216+
) -> int:
217+
"""Writes content as str."""
218+
return self.path.write_text(data, encoding=encoding, errors=errors)
219+
220+
def as_posix(self) -> str:
221+
return self.path.as_posix()
222+
223+
def __truediv__(self, key: epath.PathLike) -> epath.Path:
224+
return self.path / key
225+
226+
@property
227+
def name(self) -> str:
228+
return self.path.name
229+
230+
@property
231+
def parent(self) -> epath.Path:
232+
return self.path.parent
233+
234+
235+
def local_path(path: epath.PathLike) -> LocalPath:
236+
"""Returns a LocalPath for the given path."""
237+
return _LocalPathImpl(path)

0 commit comments

Comments
 (0)