Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions model/orbax/experimental/model/core/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,10 @@
from orbax.experimental.model.core.python.function import ShloShape
from orbax.experimental.model.core.python.function import ShloTensorSpec
from orbax.experimental.model.core.python.manifest_constants import *
from orbax.experimental.model.core.python.save_lib import GlobalSupplemental
from orbax.experimental.model.core.python.save_lib import save
from orbax.experimental.model.core.python.save_lib import SaveOptions
from orbax.experimental.model.core.python.persistence_lib import GlobalSupplemental
from orbax.experimental.model.core.python.persistence_lib import load
from orbax.experimental.model.core.python.persistence_lib import save
from orbax.experimental.model.core.python.persistence_lib import SaveOptions
from orbax.experimental.model.core.python.saveable import Saveable
from orbax.experimental.model.core.python.serializable_function import SerializableFunction
from orbax.experimental.model.core.python.shlo_function import ShloFunction
Expand Down
9 changes: 8 additions & 1 deletion model/orbax/experimental/model/core/python/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@
"""File utilities."""

import contextlib

import os

_file_opener = open
_mkdir_p = lambda path: os.makedirs(path, exist_ok=True)



@contextlib.contextmanager
Expand All @@ -28,3 +30,8 @@ def open_file(filename: str, mode: str):
yield f
finally:
f.close()


def mkdir_p(path: str) -> None:
"""Creates a directory, creating parent directories as needed."""
_mkdir_p(path)
17 changes: 7 additions & 10 deletions model/orbax/experimental/model/core/python/manifest_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@

"""Manifest model format constants."""

MANIFEST_VERSION_FILENAME = 'orbax_model_version.txt'
# The filename of the model version metadata file relative to the save
# directory.
MODEL_VERSION_FILENAME = 'orbax_model_version.txt'

# The file path of the manifest proto file
MANIFEST_FILE_PATH_KEY = 'manifest_file_path'
# TODO(b/439870164): Update the `MANIFEST_FILENAME` to be `MANIFEST_FILE_PATH`
# and treat it as a configurable path
MANIFEST_FILENAME = 'manifest.pb'
# The file path of the manifest proto file relative to the save directory.
MANIFEST_FILE_PATH = 'manifest.pb'

# The version of the manifest
VERSION_KEY = 'version'
# The version of the manifest.
MANIFEST_VERSION = '0.0.1'

# The mime type of the manifest proto file
MIME_TYPE_KEY = 'mime_type'
# The mime type of the manifest proto file.
MANIFEST_MIME_TYPE = 'application/protobuf; type=orbax_model_manifest.Manifest'
26 changes: 2 additions & 24 deletions model/orbax/experimental/model/core/python/manifest_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from collections.abc import Mapping, Sequence
from absl import logging
from orbax.experimental.model.core.protos import manifest_pb2
from orbax.experimental.model.core.python import manifest_constants
from orbax.experimental.model.core.python import unstructured_data
from orbax.experimental.model.core.python.device_assignment import DeviceAssignment
from orbax.experimental.model.core.python.function import Function
Expand All @@ -29,6 +28,7 @@
from orbax.experimental.model.core.python.unstructured_data import UnstructuredData
from orbax.experimental.model.core.python.value import ExternalValue


def _build_function(
fn: Function,
path: str,
Expand Down Expand Up @@ -63,7 +63,7 @@ def _build_function(
supp_proto = supp.proto
if supp.ext_name is not None:
filename = unstructured_data.build_filename_from_extension(
name + "_supplemental", supp.ext_name
name + "_" + supp_name + "_supplemental", supp.ext_name
)
supp_proto = unstructured_data.write_inlined_data_to_file(
supp_proto, path, filename
Expand Down Expand Up @@ -115,28 +115,6 @@ def _is_seq_of_functions(obj: Saveable) -> bool:
)


def build_manifest_version_file() -> str:
"""Builds a manifest version file content."""

# TODO(b/365967674): Remove this check once the manifest filename is
# configurable by the manifest version file. Currently, the manifest filename
# is hardcoded to "manifest.pb" in OBM & JSV codebase and that needs to be
# updated first.
if manifest_constants.MANIFEST_FILENAME != "manifest.pb":
raise ValueError(
"Currently, only manifest.pb is supported as the manifest filename."
)

return (
f"{manifest_constants.MANIFEST_FILE_PATH_KEY}:"
f' "{manifest_constants.MANIFEST_FILENAME}"\n'
f"{manifest_constants.VERSION_KEY}:"
f' "{manifest_constants.MANIFEST_VERSION}"\n'
f"{manifest_constants.MIME_TYPE_KEY}:"
f' "{manifest_constants.MANIFEST_MIME_TYPE}"\n'
)


def build_manifest_proto(
obm_module: dict[str, Saveable],
path: str,
Expand Down
10 changes: 0 additions & 10 deletions model/orbax/experimental/model/core/python/manifest_util_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,15 +64,5 @@ def test_build_device_assignment_by_coords_proto(self):
self.assertEqual(device.core_on_chip, 0) # Proto default


def test_build_manifest_version_file_content(self):
content = manifest_util.build_manifest_version_file()
expected_content = (
'manifest_file_path: "manifest.pb"\n'
'version: "0.0.1"\n'
'mime_type: "application/protobuf; type=orbax_model_manifest.Manifest"\n'
)
self.assertEqual(content, expected_content)


if __name__ == '__main__':
absltest.main()
84 changes: 84 additions & 0 deletions model/orbax/experimental/model/core/python/metadata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Model version metadata and its serialization."""

import dataclasses
from orbax.experimental.model.core.python import file_utils


@dataclasses.dataclass
class ModelVersion:
"""Model version metadata."""

_VERSION_KEY = 'version'
_MIME_TYPE_KEY = 'mime_type'
_MANIFEST_FILE_PATH_KEY = 'manifest_file_path'

version: str
mime_type: str
manifest_file_path: str

def save(self, path: str) -> None:
"""Saves the model version metadata to a file."""
with file_utils.open_file(path, 'w') as f:
f.write(f'{self._MANIFEST_FILE_PATH_KEY}: "{self.manifest_file_path}"\n')
f.write(f'{self._VERSION_KEY}: "{self.version}"\n')
f.write(f'{self._MIME_TYPE_KEY}: "{self.mime_type}"\n')

@classmethod
def load(cls, path: str) -> 'ModelVersion':
"""Loads the model version metadata from a file."""

version = ''
mime_type = ''
manifest_file_path = ''

with file_utils.open_file(path, 'r') as f:
for line in f:
line = line.strip()
if not line:
continue
if ':' not in line:
raise ValueError(f'Malformed line: {line}')

key, value = line.split(':', 1)
key = key.strip()
value = value.strip()

if not value.startswith('"') or not value.endswith('"'):
raise ValueError('All values must be double-quoted')

value = value[1:-1]
if key == cls._MANIFEST_FILE_PATH_KEY:
manifest_file_path = value
elif key == cls._VERSION_KEY:
version = value
elif key == cls._MIME_TYPE_KEY:
mime_type = value
else:
raise ValueError(f'Unknown key: {key}')

if not version:
raise ValueError('Version is empty')
if not mime_type:
raise ValueError('MIME type is empty')
if not manifest_file_path:
raise ValueError('Manifest file path is empty')

return cls(
version=version,
mime_type=mime_type,
manifest_file_path=manifest_file_path,
)
116 changes: 116 additions & 0 deletions model/orbax/experimental/model/core/python/metadata_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright 2025 The Orbax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from absl.testing import absltest
from orbax.experimental.model.core.python import file_utils
from orbax.experimental.model.core.python import metadata


class MetadataTest(absltest.TestCase):

def test_save_and_load(self):
tempdir = self.create_tempdir().full_path
path = os.path.join(tempdir, 'orbax_model_version.txt')
mv = metadata.ModelVersion(
version='1', mime_type='test_mime_type', manifest_file_path='test/path'
)
mv.save(path)
mv_loaded = metadata.ModelVersion.load(path)
self.assertEqual(mv_loaded, mv)

def test_load_fails_with_unknown_fields(self):
tempdir = self.create_tempdir().full_path
path = os.path.join(tempdir, 'orbax_model_version.txt')
file_content = """
manifest_file_path: "test/path"
version: "0.0.1"
mime_type: "test_mime_type; application/foo"
unknown_field: unknown_value
"""
with file_utils.open_file(path, 'w') as f:
f.write(file_content)

with self.assertRaises(ValueError):
metadata.ModelVersion.load(path)

def test_load_fails_with_single_quoted_values(self):
tempdir = self.create_tempdir().full_path
path = os.path.join(tempdir, 'orbax_model_version.txt')
file_content = """
manifest_file_path: "test/path"
version: '0.0.1'
mime_type: "test_mime_type; application/foo"
"""
with file_utils.open_file(path, 'w') as f:
f.write(file_content)

with self.assertRaises(ValueError):
metadata.ModelVersion.load(path)

def test_load_fails_with_malformed_file(self):
tempdir = self.create_tempdir().full_path
path = os.path.join(tempdir, 'orbax_model_version.txt')
file_content = """
manifest_file_path: "test/path"
malformed_line_no_separator
"""
with file_utils.open_file(path, 'w') as f:
f.write(file_content)

with self.assertRaises(ValueError):
metadata.ModelVersion.load(path)

def test_missing_version(self):
tempdir = self.create_tempdir().full_path
path = os.path.join(tempdir, 'orbax_model_version.txt')
file_content = """
manifest_file_path: "test/path"
mime_type: "test_mime_type; application/foo"
"""
with file_utils.open_file(path, 'w') as f:
f.write(file_content)

with self.assertRaisesRegex(ValueError, 'Version is empty'):
metadata.ModelVersion.load(path)

def test_missing_mime_type(self):
tempdir = self.create_tempdir().full_path
path = os.path.join(tempdir, 'orbax_model_version.txt')
file_content = """
version: "0.0.1"
manifest_file_path: "test/path"
"""
with file_utils.open_file(path, 'w') as f:
f.write(file_content)

with self.assertRaisesRegex(ValueError, 'MIME type is empty'):
metadata.ModelVersion.load(path)

def test_missing_manifest_file_path(self):
tempdir = self.create_tempdir().full_path
path = os.path.join(tempdir, 'orbax_model_version.txt')
file_content = """
version: "0.0.1"
mime_type: "test_mime_type; application/foo"
"""
with file_utils.open_file(path, 'w') as f:
f.write(file_content)

with self.assertRaisesRegex(ValueError, 'Manifest file path is empty'):
metadata.ModelVersion.load(path)


if __name__ == '__main__':
absltest.main()
Loading
Loading