diff --git a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
index ba88b2e648..fc3fb5bb0a 100644
--- a/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
+++ b/test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -13,9 +13,7 @@
 from torch._inductor.utils import run_and_get_code
 from torch.testing import FileCheck
 from torch.testing._internal import common_utils
-from torch.testing._internal.common_utils import (
-    run_tests,
-)
+from torch.testing._internal.common_utils import run_tests
 
 from torchao.quantization import (
     Float8DynamicActivationFloat8WeightConfig,
@@ -464,6 +462,197 @@ def test_index_select(self):
             x_fp8.dequantize()[1], x_fp8_1.dequantize(), atol=0, rtol=0
         )
 
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize(
+        "sizes",
+        [
+            ((128,), 256, 128),
+            ((32, 128), 64, 256),
+        ],
+    )
+    def test_unsqueeze_operation(self, granularity, sizes):
+        """Test aten.unsqueeze.default operation on Float8Tensor"""
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        dtype = torch.bfloat16
+        device = "cuda"
+        M, N, K = sizes
+
+        # Create a linear layer and quantize it
+        linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
+        quantize_(linear, config)
+
+        original_weight = linear.weight
+        original_shape = original_weight.shape
+
+        # Test unsqueeze operation at dim=0 (only supported dimension)
+        unsqueezed_weight = original_weight.unsqueeze(0)
+
+        # Verify the unsqueezed tensor has correct shape
+        expected_shape = [1] + list(original_shape)
+        self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))
+
+        # Verify qdata and scale shapes
+        expected_qdata_shape = [1] + list(original_weight.qdata.shape)
+        expected_scale_shape = [1] + list(original_weight.scale.shape)
+
+        self.assertEqual(
+            unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
+        )
+        self.assertEqual(
+            unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
+        )
+
+        # Verify block_size is correctly updated
+        expected_block_size = []
+        for i in range(len(expected_shape)):
+            expected_block_size.append(expected_shape[i] // expected_scale_shape[i])
+
+        self.assertEqual(unsqueezed_weight.block_size, expected_block_size)
+
+        # Test that metadata is preserved
+        self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
+        self.assertEqual(
+            unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
+        )
+        self.assertEqual(
+            unsqueezed_weight.kernel_preference, original_weight.kernel_preference
+        )
+        self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)
+
+        # Test numerical correctness
+        original_dequant = original_weight.dequantize()
+        unsqueezed_dequant = unsqueezed_weight.dequantize()
+        expected_dequant = original_dequant.unsqueeze(0)
+
+        self.assertEqual(unsqueezed_dequant, expected_dequant)
+
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    def test_unsqueeze_error_cases(self, granularity):
+        """Test error cases for aten.unsqueeze.default operation"""
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        # Create a linear layer and quantize it
+        linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
+        quantize_(linear, config)
+
+        weight = linear.weight
+
+        # Test that unsqueezing on unsupported dimensions raises an error
+        with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"):
+            weight.unsqueeze(1)  # dim=1 should not be supported
+
+    @common_utils.parametrize("granularity", [PerTensor(), PerRow()])
+    @common_utils.parametrize("slice_dim", [0, 1, 2])
+    @common_utils.parametrize(
+        "tensor_shape",
+        [
+            (8, 128, 256),  # 3D tensor: batch, seq_len, hidden_dim
+            (4, 64, 128),  # smaller 3D tensor
+        ],
+    )
+    def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
+        """Test slicing operations on 3D Float8Tensor across all dimensions"""
+        config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
+        dtype = torch.bfloat16
+        device = "cuda"
+
+        B, S, H = tensor_shape
+
+        # Create a 3D tensor and quantize it (simulating a batched weight tensor)
+        original_tensor = torch.randn(B, S, H, dtype=dtype, device=device)
+
+        # Create Float8Tensor from the 3D high-precision tensor
+        float8_tensor = Float8Tensor.from_hp(
+            original_tensor,
+            granularity=granularity,
+            mm_config=config.mm_config,
+        )
+
+        slice_size = tensor_shape[slice_dim]
+        start_idx = 1
+        end_idx = slice_size - 1
+
+        # Perform slicing on the specified dimension
+        if slice_dim == 0:
+            sliced_tensor = float8_tensor[start_idx:end_idx, :, :]
+            expected_qdata = float8_tensor.qdata[start_idx:end_idx, :, :]
+            expected_scale = float8_tensor.scale[start_idx:end_idx, :]
+        elif slice_dim == 1:
+            sliced_tensor = float8_tensor[:, start_idx:end_idx, :]
+            expected_qdata = float8_tensor.qdata[:, start_idx:end_idx, :]
+            expected_scale = float8_tensor.scale[:, start_idx:end_idx]
+        elif slice_dim == 2:
+            sliced_tensor = float8_tensor[:, :, start_idx:end_idx]
+            expected_qdata = float8_tensor.qdata[:, :, start_idx:end_idx]
+            expected_scale = float8_tensor.scale[:, :]
+
+        if isinstance(granularity, PerTensor):
+            # Per-tensor quantization: scale should remain scalar
+            expected_scale = float8_tensor.scale
+
+        # Verify the sliced tensor shape
+        expected_shape = list(tensor_shape)
+        expected_shape[slice_dim] = end_idx - start_idx
+        self.assertEqual(sliced_tensor.shape, torch.Size(expected_shape))
+
+        # Verify qdata shape matches
+        self.assertEqual(sliced_tensor.qdata.shape, torch.Size(expected_shape))
+        self.assertEqual(sliced_tensor.qdata, expected_qdata)
+
+        # Verify scale shape is correct based on granularity and slice dimension
+        if isinstance(granularity, PerTensor):
+            # Per-tensor quantization: scale should remain scalar
+            self.assertEqual(sliced_tensor.scale.numel(), 1)
+        else:
+            # Per-row quantization: scale shape depends on which dimension we sliced
+            if slice_dim == 0:
+                # Slicing batch dimension affects scale
+                expected_scale_shape = list(float8_tensor.scale.shape)
+                expected_scale_shape[0] = end_idx - start_idx
+                self.assertEqual(
+                    sliced_tensor.scale.shape, torch.Size(expected_scale_shape)
+                )
+            elif slice_dim == 1:
+                # Slicing sequence dimension affects scale
+                expected_scale_shape = list(float8_tensor.scale.shape)
+                expected_scale_shape[1] = end_idx - start_idx
+                self.assertEqual(
+                    sliced_tensor.scale.shape, torch.Size(expected_scale_shape)
+                )
+            else:
+                # Slicing hidden dimension (dim=2) typically doesn't affect scale in per-row quantization
+                self.assertEqual(sliced_tensor.scale.shape, float8_tensor.scale.shape)
+
+        self.assertEqual(sliced_tensor.scale, expected_scale)
+
+        # Verify block_size is correctly updated
+        self.assertEqual(len(sliced_tensor.block_size), len(expected_shape))
+        for i in range(len(expected_shape)):
+            expected_block_dim = min(float8_tensor.block_size[i], expected_shape[i])
+            self.assertEqual(sliced_tensor.block_size[i], expected_block_dim)
+
+        # Test that metadata is preserved
+        self.assertEqual(sliced_tensor.mm_config, float8_tensor.mm_config)
+        self.assertEqual(sliced_tensor.act_quant_kwargs, float8_tensor.act_quant_kwargs)
+        self.assertEqual(
+            sliced_tensor.kernel_preference, float8_tensor.kernel_preference
+        )
+        self.assertEqual(sliced_tensor.dtype, float8_tensor.dtype)
+
+        # Test numerical correctness by comparing dequantized results
+        original_dequantized = float8_tensor.dequantize()
+        if slice_dim == 0:
+            sliced_original = original_dequantized[start_idx:end_idx, :, :]
+        elif slice_dim == 1:
+            sliced_original = original_dequantized[:, start_idx:end_idx, :]
+        elif slice_dim == 2:
+            sliced_original = original_dequantized[:, :, start_idx:end_idx]
+        sliced_dequantized = sliced_tensor.dequantize()
+
+        self.assertEqual(sliced_dequantized, sliced_original)
 
 
 common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
index 46d4ca5426..ff89bbe576 100644
--- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
+++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -418,10 +418,10 @@ def _(func, types, args, kwargs):
 
 @implements(aten.slice.Tensor)
 def _(func, types, args, kwargs):
-    """Only supports slicing for dim == 1 and dim == 2
-    original tensor shape has dimension (N, K)
-    qdata has dimension (N, K)
-    scale (per row quantization) has dimension: (N,)
+    """Supports slicing for 1D, 2D, and 3D tensors
+    original tensor shape has dimension (N, K) or (E, N, K)
+    qdata has dimension (N, K) or (E, N, K)
+    scale (per row quantization) has dimension: (N,) or (E, N)
 
     since qdata has the same dimension as original tensor, we can directly slice that
     for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1
@@ -431,12 +431,14 @@ def _(func, types, args, kwargs):
     """
     self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
     assert step == 1
-    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
+    assert dim == 0 or dim == 1 or dim == 2, (
+        f"Only dim == 0, 1, or 2 is supported, got: dim={dim}"
+    )
     if end >= self.shape[dim]:
         end = self.shape[dim]
 
-    assert self.qdata.ndim == 2, (
-        f"Expected packed weight to have dim 2, got {self.qdata.dim}"
+    assert self.qdata.ndim == 2 or self.qdata.ndim == 3, (
+        f"Expected packed weight to have ndim 2 or 3, got: ndim={self.qdata.ndim}"
     )
 
     # Always slice the qdata
@@ -639,6 +641,28 @@ def _(func, types, args, kwargs):
         return return_and_correct_aliasing(func, args, kwargs, new_float8_tensor)
 
 
+@implements(aten.unsqueeze.default)
+def _(func, types, args, kwargs):
+    self, dim = args
+    assert dim == 0, f"Only dim == 0 is supported, got: {dim}"
+    qdata = self.qdata.unsqueeze(dim=dim)
+    scale = self.scale.unsqueeze(dim=dim)
+    block_size = []
+    for i in range(len(qdata.shape)):
+        block_size.append(qdata.shape[i] // scale.shape[i])
+
+    new = self.__class__(
+        qdata,
+        scale,
+        block_size,
+        self.mm_config,
+        self.act_quant_kwargs,
+        self.kernel_preference,
+        self.dtype,
+    )
+    return return_and_correct_aliasing(func, args, kwargs, new)
+
+
 Float8Tensor.__module__ = "torchao.quantization"
 
 # Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
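A minimal usage sketch of what the two new overrides enable together (not part of the diff; it assumes a CUDA device, per-row granularity, and that `Float8Tensor` is importable from `torchao.quantization`, where the diff sets its `__module__`):

import torch

from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    Float8Tensor,
    PerRow,
    quantize_,
)

config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())

# aten.unsqueeze.default: add a leading batch dim to a quantized 2D weight;
# qdata, scale, and block_size all pick up the new dimension
linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, config)
w3d = linear.weight.unsqueeze(0)
assert w3d.shape == (1, 256, 128)
torch.testing.assert_close(w3d.dequantize(), linear.weight.dequantize().unsqueeze(0))

# aten.slice.Tensor on a 3D Float8Tensor: qdata is sliced directly and the
# per-row scale is sliced along the matching dim (a no-op for the last dim)
hp = torch.randn(8, 128, 256, dtype=torch.bfloat16, device="cuda")
t = Float8Tensor.from_hp(hp, granularity=PerRow(), mm_config=config.mm_config)
s = t[:, 1:127, :]
torch.testing.assert_close(s.dequantize(), t.dequantize()[:, 1:127, :])

Both assertions mirror the dequantize-then-compare checks used in the new tests: each op should commute with dequantization, with block_size recomputed per dim as the ratio of the qdata shape to the scale shape.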