@@ -41,16 +41,6 @@ def _get_expected_proto_from_tpu_comp_env(field_str: str, proto_str: str):
41
41
'xla_tpu_async_copy_bandwidth_scaling_factor' : '0.19125064716453793' ,
42
42
}
43
43
44
- XLA_FLAGS_PROTO_FORMATTED = [
45
- 'xla_jf_rematerialization_percent_shared_memory_limit: 99' ,
46
- 'xla_tpu_allocate_scoped_vmem_at_same_offset: false' ,
47
- (
48
- 'xla_tpu_alternate_memory_benefit_scaling_factor_for_large_buffers:'
49
- " 'NO_SCALE'"
50
- ),
51
- 'xla_tpu_memory_bound_loop_optimizer_options: {enabled:false}' ,
52
- 'xla_tpu_async_copy_bandwidth_scaling_factor: 0.19125064716453793' ,
53
- ]
54
44
55
45
EXPECTED_ENV = tpu_comp_env_pb2 .TpuCompilationEnvironment (
56
46
xla_jf_rematerialization_percent_shared_memory_limit = 99 ,
@@ -122,19 +112,8 @@ def test_parse_flag_from_string_nonexistent_flag(self):
122
112
with self .assertRaisesRegex (ValueError , 'Flag not found: nonexistent_flag' ):
123
113
compile_options_util .parse_flag_from_string ('nonexistent_flag' , 'value' )
124
114
125
- @parameterized .named_parameters (
126
- dict (
127
- testcase_name = 'dict_xla_flags' ,
128
- xla_flags = XLA_FLAGS_DICT ,
129
- merge_fn = compile_options_util .merge_flags_into_compile_options ,
130
- ),
131
- dict (
132
- testcase_name = 'proto_formatted_xla_flags' ,
133
- xla_flags = XLA_FLAGS_PROTO_FORMATTED ,
134
- merge_fn = compile_options_util .merge_proto_formatted_flags_into_compile_options ,
135
- ),
136
- )
137
- def test_merge_flags_into_compile_options (self , xla_flags , merge_fn ):
115
+ def test_merge_flags_into_compile_options (self ):
116
+ xla_flags = XLA_FLAGS_DICT
138
117
# Initialize the environment with some values.
139
118
env = tpu_comp_env_pb2 .TpuCompilationEnvironment ()
140
119
# Values that should be overridden.
@@ -144,7 +123,7 @@ def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
144
123
env .xla_tpu_wait_n_cycles_before_program_termination = 1234
145
124
146
125
# Merge the flags into the environment.
147
- merge_fn (xla_flags , env )
126
+ compile_options_util . merge_flags_into_compile_options (xla_flags , env )
148
127
self .assertEqual (
149
128
env .xla_jf_rematerialization_percent_shared_memory_limit , 99
150
129
)
@@ -170,11 +149,6 @@ def test_merge_flags_into_compile_options(self, xla_flags, merge_fn):
170
149
xla_flags = [f'--{ k } ={ v } ' for k , v in XLA_FLAGS_DICT .items ()],
171
150
expected_env = EXPECTED_ENV ,
172
151
),
173
- dict (
174
- testcase_name = 'proto_formatted_xla_flags' ,
175
- xla_flags = XLA_FLAGS_PROTO_FORMATTED ,
176
- expected_env = EXPECTED_ENV ,
177
- ),
178
152
dict (
179
153
testcase_name = 'no_xla_flags' ,
180
154
xla_flags = None ,
@@ -204,7 +178,7 @@ def test_generate_tpu_compilation_env_invalid_flag_format(self):
204
178
ValueError ,
205
179
'Flag xla_tpu_allocate_scoped_vmem_at_same_offset: false does not start'
206
180
" with '--'. All flags must be in the format of"
207
- ' --flag_name=flag_value when using this format .' ,
181
+ ' --flag_name=flag_value.' ,
208
182
):
209
183
compile_options_util .generate_tpu_compilation_env (
210
184
xla_flags = [
0 commit comments