Skip to content

Commit dc6f085

Browse files
author
Orbax Authors
committed
remove support for proto formatted xla flags
PiperOrigin-RevId: 811865727
1 parent c566855 commit dc6f085

File tree

2 files changed

+16
-63
lines changed

2 files changed

+16
-63
lines changed

model/orbax/experimental/model/core/python/compile_options_util.py

Lines changed: 12 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -49,20 +49,16 @@ def generate_tpu_compilation_env(
4949
)
5050
# Override with supplied XLA flags if any is provided.
5151
if xla_flags:
52-
is_proto_formatted = False if xla_flags[0].startswith('--') else True
53-
if is_proto_formatted:
54-
merge_proto_formatted_flags_into_compile_options(xla_flags, env)
55-
else:
56-
parsed_flags = {}
57-
for flag in xla_flags:
58-
if not flag.startswith('--'):
59-
raise ValueError(
60-
f"Flag {flag} does not start with '--'. All flags must be in the"
61-
' format of --flag_name=flag_value when using this format.'
62-
)
63-
flag_name, flag_value = flag[2:].split('=', 1)
64-
parsed_flags[flag_name] = flag_value
65-
merge_flags_into_compile_options(parsed_flags, env)
52+
parsed_flags = {}
53+
for flag in xla_flags:
54+
if not flag.startswith('--'):
55+
raise ValueError(
56+
f"Flag {flag} does not start with '--'. All flags must be in the"
57+
' format of --flag_name=flag_value.'
58+
)
59+
flag_name, flag_value = flag[2:].split('=', 1)
60+
parsed_flags[flag_name] = flag_value
61+
merge_flags_into_compile_options(parsed_flags, env)
6662

6763
# Pack the TPU compilation environment into a compilation env proto.
6864
any_proto = any_pb2.Any()
@@ -231,8 +227,8 @@ def merge_flags_into_compile_options(
231227
Args:
232228
xla_flags: A mapping of XLA flag names to their string values. These flags
233229
will be parsed and merged into the `env` proto.
234-
env: The TpuCompilationEnvironment proto to merge the flags into. This
235-
proto will be modified in place.
230+
env: The TpuCompilationEnvironment proto to merge the flags into. This proto
231+
will be modified in place.
236232
"""
237233
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
238234
for flag_name, value in xla_flags.items():
@@ -245,20 +241,3 @@ def merge_flags_into_compile_options(
245241
# For scalar types, we can set the attribute directly.
246242
setattr(env_override, field_descriptor.name, parsed_value)
247243
env.MergeFrom(env_override)
248-
249-
250-
# TODO(b/438187387): remove this path and only allow the "--flag=value" format.
251-
def merge_proto_formatted_flags_into_compile_options(
252-
xla_flags: Sequence[str],
253-
env: tpu_comp_env_pb2.TpuCompilationEnvironment,
254-
):
255-
"""Merges flags into a proto."""
256-
env_override = tpu_comp_env_pb2.TpuCompilationEnvironment()
257-
xla_flags_str = '\n'.join(xla_flags)
258-
try:
259-
text_format.Parse(xla_flags_str, env_override)
260-
except text_format.ParseError as e:
261-
raise ValueError(
262-
f'Error parsing supplied XLA flag overrides {xla_flags_str}.'
263-
) from e
264-
env.MergeFrom(env_override)

model/orbax/experimental/model/core/python/compile_options_util_test.py

Lines changed: 4 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,6 @@ def _get_expected_proto_from_tpu_comp_env(field_str: str, proto_str: str):
4141
'xla_tpu_async_copy_bandwidth_scaling_factor': '0.19125064716453793',
4242
}
4343

44-
XLA_FLAGS_PROTO_FORMATTED = [
45-
'xla_jf_rematerialization_percent_shared_memory_limit: 99',
46-
'xla_tpu_allocate_scoped_vmem_at_same_offset: false',
47-
(
48-
'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers:'
49-
" 'NO_SCALE'"
50-
),
51-
'xla_tpu_memory_bound_loop_optimizer_options: {enabled:false}',
52-
'xla_tpu_async_copy_bandwidth_scaling_factor: 0.19125064716453793',
53-
]
5444

5545
EXPECTED_ENV = tpu_comp_env_pb2.TpuCompilationEnvironment(
5646
xla_jf_rematerialization_percent_shared_memory_limit=99,
@@ -122,19 +112,8 @@ def test_parse_flag_from_string_nonexistent_flag(self):
122112
with self.assertRaisesRegex(ValueError, 'Flag not found: nonexistent_flag'):
123113
compile_options_util.parse_flag_from_string('nonexistent_flag', 'value')
124114

125-
@parameterized.named_parameters(
126-
dict(
127-
testcase_name='dict_xla_flags',
128-
xla_flags=XLA_FLAGS_DICT,
129-
merge_fn=compile_options_util.merge_flags_into_compile_options,
130-
),
131-
dict(
132-
testcase_name='proto_formatted_xla_flags',
133-
xla_flags=XLA_FLAGS_PROTO_FORMATTED,
134-
merge_fn=compile_options_util.merge_proto_formatted_flags_into_compile_options,
135-
),
136-
)
137-
def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
115+
def test_merge_flags_into_compile_options(self):
116+
xla_flags = XLA_FLAGS_DICT
138117
# Initialize the environment with some values.
139118
env = tpu_comp_env_pb2.TpuCompilationEnvironment()
140119
# Values that should be overridden.
@@ -144,7 +123,7 @@ def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
144123
env.xla_tpu_wait_n_cycles_before_program_termination = 1234
145124

146125
# Merge the flags into the environment.
147-
merge_fn(xla_flags, env)
126+
compile_options_util.merge_flags_into_compile_options(xla_flags, env)
148127
self.assertEqual(
149128
env.xla_jf_rematerialization_percent_shared_memory_limit, 99
150129
)
@@ -170,11 +149,6 @@ def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
170149
xla_flags=[f'--{k}={v}' for k, v in XLA_FLAGS_DICT.items()],
171150
expected_env=EXPECTED_ENV,
172151
),
173-
dict(
174-
testcase_name='proto_formatted_xla_flags',
175-
xla_flags=XLA_FLAGS_PROTO_FORMATTED,
176-
expected_env=EXPECTED_ENV,
177-
),
178152
dict(
179153
testcase_name='no_xla_flags',
180154
xla_flags=None,
@@ -204,7 +178,7 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
204178
ValueError,
205179
'Flag xla_tpu_allocate_scoped_vmem_at_same_offset: false does not start'
206180
" with '--'. All flags must be in the format of"
207-
' --flag_name=flag_value when using this format.',
181+
' --flag_name=flag_value.',
208182
):
209183
compile_options_util.generate_tpu_compilation_env(
210184
xla_flags=[

0 commit comments

Comments
 (0)