Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions export/orbax/export/obm_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Export class that implements the save and load abstract class defined in Export Base for use with the Orbax Model export format."""

from collections.abc import Callable, Mapping, Sequence
import copy
import functools
import itertools
import os
Expand Down Expand Up @@ -45,14 +46,6 @@ def __init__(
serving_configs: Sequence[osc.ServingConfig],
):
"""Initializes the ObmExport class."""
if module.export_version != constants.ExportModelType.ORBAX_MODEL:
raise ValueError(
"JaxModule export version is not of type ORBAX_MODEL. Please use the"
" correct export_version. Expected ORBAX_MODEL, got"
f" {module.export_version}"
)

obm_model_module = module.export_module()

def save(
self,
Expand Down
Loading