Skip to content

Commit cbfef40

Browse files
author
Orbax Authors
committed
Add unit tests for saving/loading model
PiperOrigin-RevId: 811851235
1 parent dc6f085 commit cbfef40

14 files changed

+726
-278
lines changed

model/orbax/experimental/model/core/python/__init__.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
from orbax.experimental.model.core.python.function import ShloShape
3636
from orbax.experimental.model.core.python.function import ShloTensorSpec
3737
from orbax.experimental.model.core.python.manifest_constants import *
38-
from orbax.experimental.model.core.python.save_lib import GlobalSupplemental
39-
from orbax.experimental.model.core.python.save_lib import save
40-
from orbax.experimental.model.core.python.save_lib import SaveOptions
38+
from orbax.experimental.model.core.python.persistence_lib import GlobalSupplemental
39+
from orbax.experimental.model.core.python.persistence_lib import load
40+
from orbax.experimental.model.core.python.persistence_lib import save
41+
from orbax.experimental.model.core.python.persistence_lib import SaveOptions
4142
from orbax.experimental.model.core.python.saveable import Saveable
4243
from orbax.experimental.model.core.python.serializable_function import SerializableFunction
4344
from orbax.experimental.model.core.python.shlo_function import ShloFunction

model/orbax/experimental/model/core/python/file_utils.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@
1515
"""File utilities."""
1616

1717
import contextlib
18-
18+
import os
1919

2020
_file_opener = open
21+
_mkdir_p = lambda path: os.makedirs(path, exist_ok=True)
22+
2123

2224

2325
@contextlib.contextmanager
@@ -28,3 +30,8 @@ def open_file(filename: str, mode: str):
2830
yield f
2931
finally:
3032
f.close()
33+
34+
35+
def mkdir_p(path: str) -> None:
36+
"""Creates a directory, creating parent directories as needed."""
37+
_mkdir_p(path)

model/orbax/experimental/model/core/python/manifest_constants.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,15 @@
1414

1515
"""Manifest model format constants."""
1616

17-
MANIFEST_VERSION_FILENAME = 'orbax_model_version.txt'
17+
# The filename of the model version metadata file relative to the save
18+
# directory.
19+
MODEL_VERSION_FILENAME = 'orbax_model_version.txt'
1820

19-
# The file path of the manifest proto file
20-
MANIFEST_FILE_PATH_KEY = 'manifest_file_path'
21-
# TODO(b/439870164): Update the `MANIFEST_FILENAME` to be `MANIFEST_FILE_PATH`
22-
# and treat it as a configurable path
23-
MANIFEST_FILENAME = 'manifest.pb'
21+
# The file path of the manifest proto file relative to the save directory.
22+
MANIFEST_FILE_PATH = 'manifest.pb'
2423

25-
# The version of the manifest
26-
VERSION_KEY = 'version'
24+
# The version of the manifest.
2725
MANIFEST_VERSION = '0.0.1'
2826

29-
# The mime type of the manifest proto file
30-
MIME_TYPE_KEY = 'mime_type'
27+
# The mime type of the manifest proto file.
3128
MANIFEST_MIME_TYPE = 'application/protobuf; type=orbax_model_manifest.Manifest'

model/orbax/experimental/model/core/python/manifest_util.py

Lines changed: 2 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from collections.abc import Mapping, Sequence
1919
from absl import logging
2020
from orbax.experimental.model.core.protos import manifest_pb2
21-
from orbax.experimental.model.core.python import manifest_constants
2221
from orbax.experimental.model.core.python import unstructured_data
2322
from orbax.experimental.model.core.python.device_assignment import DeviceAssignment
2423
from orbax.experimental.model.core.python.function import Function
@@ -29,6 +28,7 @@
2928
from orbax.experimental.model.core.python.unstructured_data import UnstructuredData
3029
from orbax.experimental.model.core.python.value import ExternalValue
3130

31+
3232
def _build_function(
3333
fn: Function,
3434
path: str,
@@ -63,7 +63,7 @@ def _build_function(
6363
supp_proto = supp.proto
6464
if supp.ext_name is not None:
6565
filename = unstructured_data.build_filename_from_extension(
66-
name + "_supplemental", supp.ext_name
66+
name + "_" + supp_name + "_supplemental", supp.ext_name
6767
)
6868
supp_proto = unstructured_data.write_inlined_data_to_file(
6969
supp_proto, path, filename
@@ -115,28 +115,6 @@ def _is_seq_of_functions(obj: Saveable) -> bool:
115115
)
116116

117117

118-
def build_manifest_version_file() -> str:
119-
"""Builds a manifest version file content."""
120-
121-
# TODO(b/365967674): Remove this check once the manifest filename is
122-
# configurable by the manifest version file. Currently, the manifest filename
123-
# is hardcoded to "manifest.pb" in OBM & JSV codebase and that needs to be
124-
# updated first.
125-
if manifest_constants.MANIFEST_FILENAME != "manifest.pb":
126-
raise ValueError(
127-
"Currently, only manifest.pb is supported as the manifest filename."
128-
)
129-
130-
return (
131-
f"{manifest_constants.MANIFEST_FILE_PATH_KEY}:"
132-
f' "{manifest_constants.MANIFEST_FILENAME}"\n'
133-
f"{manifest_constants.VERSION_KEY}:"
134-
f' "{manifest_constants.MANIFEST_VERSION}"\n'
135-
f"{manifest_constants.MIME_TYPE_KEY}:"
136-
f' "{manifest_constants.MANIFEST_MIME_TYPE}"\n'
137-
)
138-
139-
140118
def build_manifest_proto(
141119
obm_module: dict[str, Saveable],
142120
path: str,

model/orbax/experimental/model/core/python/manifest_util_test.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,5 @@ def test_build_device_assignment_by_coords_proto(self):
6464
self.assertEqual(device.core_on_chip, 0) # Proto default
6565

6666

67-
def test_build_manifest_version_file_content(self):
68-
content = manifest_util.build_manifest_version_file()
69-
expected_content = (
70-
'manifest_file_path: "manifest.pb"\n'
71-
'version: "0.0.1"\n'
72-
'mime_type: "application/protobuf; type=orbax_model_manifest.Manifest"\n'
73-
)
74-
self.assertEqual(content, expected_content)
75-
76-
7767
if __name__ == '__main__':
7868
absltest.main()
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Model version metadata and its serialization."""
16+
17+
import dataclasses
18+
from orbax.experimental.model.core.python import file_utils
19+
20+
21+
@dataclasses.dataclass
22+
class ModelVersion:
23+
"""Model version metadata."""
24+
25+
_VERSION_KEY = 'version'
26+
_MIME_TYPE_KEY = 'mime_type'
27+
_MANIFEST_FILE_PATH_KEY = 'manifest_file_path'
28+
29+
version: str
30+
mime_type: str
31+
manifest_file_path: str
32+
33+
def save(self, path: str) -> None:
34+
"""Saves the model version metadata to a file."""
35+
with file_utils.open_file(path, 'w') as f:
36+
f.write(f'{self._MANIFEST_FILE_PATH_KEY}: "{self.manifest_file_path}"\n')
37+
f.write(f'{self._VERSION_KEY}: "{self.version}"\n')
38+
f.write(f'{self._MIME_TYPE_KEY}: "{self.mime_type}"\n')
39+
40+
@classmethod
41+
def load(cls, path: str) -> 'ModelVersion':
42+
"""Loads the model version metadata from a file."""
43+
44+
version = ''
45+
mime_type = ''
46+
manifest_file_path = ''
47+
48+
with file_utils.open_file(path, 'r') as f:
49+
for line in f:
50+
line = line.strip()
51+
if not line:
52+
continue
53+
if ':' not in line:
54+
raise ValueError(f'Malformed line: {line}')
55+
56+
key, value = line.split(':', 1)
57+
key = key.strip()
58+
value = value.strip()
59+
60+
if not value.startswith('"') or not value.endswith('"'):
61+
raise ValueError('All values must be double-quoted')
62+
63+
value = value[1:-1]
64+
if key == cls._MANIFEST_FILE_PATH_KEY:
65+
manifest_file_path = value
66+
elif key == cls._VERSION_KEY:
67+
version = value
68+
elif key == cls._MIME_TYPE_KEY:
69+
mime_type = value
70+
else:
71+
raise ValueError(f'Unknown key: {key}')
72+
73+
if not version:
74+
raise ValueError('Version is empty')
75+
if not mime_type:
76+
raise ValueError('MIME type is empty')
77+
if not manifest_file_path:
78+
raise ValueError('Manifest file path is empty')
79+
80+
return cls(
81+
version=version,
82+
mime_type=mime_type,
83+
manifest_file_path=manifest_file_path,
84+
)
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
# Copyright 2025 The Orbax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import os
16+
from absl.testing import absltest
17+
from orbax.experimental.model.core.python import file_utils
18+
from orbax.experimental.model.core.python import metadata
19+
20+
21+
class MetadataTest(absltest.TestCase):
22+
23+
def test_save_and_load(self):
24+
tempdir = self.create_tempdir().full_path
25+
path = os.path.join(tempdir, 'orbax_model_version.txt')
26+
mv = metadata.ModelVersion(
27+
version='1', mime_type='test_mime_type', manifest_file_path='test/path'
28+
)
29+
mv.save(path)
30+
mv_loaded = metadata.ModelVersion.load(path)
31+
self.assertEqual(mv_loaded, mv)
32+
33+
def test_load_fails_with_unknown_fields(self):
34+
tempdir = self.create_tempdir().full_path
35+
path = os.path.join(tempdir, 'orbax_model_version.txt')
36+
file_content = """
37+
manifest_file_path: "test/path"
38+
version: "0.0.1"
39+
mime_type: "test_mime_type; application/foo"
40+
unknown_field: unknown_value
41+
"""
42+
with file_utils.open_file(path, 'w') as f:
43+
f.write(file_content)
44+
45+
with self.assertRaises(ValueError):
46+
metadata.ModelVersion.load(path)
47+
48+
def test_load_fails_with_single_quoted_values(self):
49+
tempdir = self.create_tempdir().full_path
50+
path = os.path.join(tempdir, 'orbax_model_version.txt')
51+
file_content = """
52+
manifest_file_path: "test/path"
53+
version: '0.0.1'
54+
mime_type: "test_mime_type; application/foo"
55+
"""
56+
with file_utils.open_file(path, 'w') as f:
57+
f.write(file_content)
58+
59+
with self.assertRaises(ValueError):
60+
metadata.ModelVersion.load(path)
61+
62+
def test_load_fails_with_malformed_file(self):
63+
tempdir = self.create_tempdir().full_path
64+
path = os.path.join(tempdir, 'orbax_model_version.txt')
65+
file_content = """
66+
manifest_file_path: "test/path"
67+
malformed_line_no_separator
68+
"""
69+
with file_utils.open_file(path, 'w') as f:
70+
f.write(file_content)
71+
72+
with self.assertRaises(ValueError):
73+
metadata.ModelVersion.load(path)
74+
75+
def test_missing_version(self):
76+
tempdir = self.create_tempdir().full_path
77+
path = os.path.join(tempdir, 'orbax_model_version.txt')
78+
file_content = """
79+
manifest_file_path: "test/path"
80+
mime_type: "test_mime_type; application/foo"
81+
"""
82+
with file_utils.open_file(path, 'w') as f:
83+
f.write(file_content)
84+
85+
with self.assertRaisesRegex(ValueError, 'Version is empty'):
86+
metadata.ModelVersion.load(path)
87+
88+
def test_missing_mime_type(self):
89+
tempdir = self.create_tempdir().full_path
90+
path = os.path.join(tempdir, 'orbax_model_version.txt')
91+
file_content = """
92+
version: "0.0.1"
93+
manifest_file_path: "test/path"
94+
"""
95+
with file_utils.open_file(path, 'w') as f:
96+
f.write(file_content)
97+
98+
with self.assertRaisesRegex(ValueError, 'MIME type is empty'):
99+
metadata.ModelVersion.load(path)
100+
101+
def test_missing_manifest_file_path(self):
102+
tempdir = self.create_tempdir().full_path
103+
path = os.path.join(tempdir, 'orbax_model_version.txt')
104+
file_content = """
105+
version: "0.0.1"
106+
mime_type: "test_mime_type; application/foo"
107+
"""
108+
with file_utils.open_file(path, 'w') as f:
109+
f.write(file_content)
110+
111+
with self.assertRaisesRegex(ValueError, 'Manifest file path is empty'):
112+
metadata.ModelVersion.load(path)
113+
114+
115+
if __name__ == '__main__':
116+
absltest.main()

0 commit comments

Comments
 (0)