195 changes: 192 additions & 3 deletions test/quantization/quantize_/workflows/float8/test_float8_tensor.py
@@ -13,9 +13,7 @@
from torch._inductor.utils import run_and_get_code
from torch.testing import FileCheck
from torch.testing._internal import common_utils
-from torch.testing._internal.common_utils import (
-    run_tests,
-)
+from torch.testing._internal.common_utils import run_tests

from torchao.quantization import (
Float8DynamicActivationFloat8WeightConfig,
@@ -464,6 +462,197 @@ def test_index_select(self):
x_fp8.dequantize()[1], x_fp8_1.dequantize(), atol=0, rtol=0
)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize(
"sizes",
[
((128,), 256, 128),
((32, 128), 64, 256),
],
)
def test_unsqueeze_operation(self, granularity, sizes):
"""Test aten.unsqueeze.default operation on Float8Tensor"""
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"
M, N, K = sizes

# Create a linear layer and quantize it
linear = torch.nn.Linear(K, N, bias=False, dtype=dtype, device=device)
quantize_(linear, config)

original_weight = linear.weight
original_shape = original_weight.shape

# Test unsqueeze operation at dim=0 (only supported dimension)
unsqueezed_weight = original_weight.unsqueeze(0)

# Verify the unsqueezed tensor has correct shape
expected_shape = [1] + list(original_shape)
self.assertEqual(unsqueezed_weight.shape, torch.Size(expected_shape))

# Verify qdata and scale shapes
expected_qdata_shape = [1] + list(original_weight.qdata.shape)
expected_scale_shape = [1] + list(original_weight.scale.shape)

self.assertEqual(
unsqueezed_weight.qdata.shape, torch.Size(expected_qdata_shape)
)
self.assertEqual(
unsqueezed_weight.scale.shape, torch.Size(expected_scale_shape)
)

# Verify block_size is correctly updated
expected_block_size = []
for i in range(len(expected_shape)):
expected_block_size.append(expected_shape[i] // expected_scale_shape[i])

self.assertEqual(unsqueezed_weight.block_size, expected_block_size)

# Test that metadata is preserved
self.assertEqual(unsqueezed_weight.mm_config, original_weight.mm_config)
self.assertEqual(
unsqueezed_weight.act_quant_kwargs, original_weight.act_quant_kwargs
)
self.assertEqual(
unsqueezed_weight.kernel_preference, original_weight.kernel_preference
)
self.assertEqual(unsqueezed_weight.dtype, original_weight.dtype)

# Test numerical correctness
original_dequant = original_weight.dequantize()
unsqueezed_dequant = unsqueezed_weight.dequantize()
expected_dequant = original_dequant.unsqueeze(0)

self.assertEqual(unsqueezed_dequant, expected_dequant)

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
def test_unsqueeze_error_cases(self, granularity):
"""Test error cases for aten.unsqueeze.default operation"""
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"

# Create a linear layer and quantize it
linear = torch.nn.Linear(128, 256, bias=False, dtype=dtype, device=device)
quantize_(linear, config)

weight = linear.weight

# Test that unsqueezing on unsupported dimensions raises an error
with self.assertRaisesRegex(AssertionError, "Only dim == 0 is supported"):
weight.unsqueeze(1) # dim=1 should not be supported

@common_utils.parametrize("granularity", [PerTensor(), PerRow()])
@common_utils.parametrize("slice_dim", [0, 1, 2])
@common_utils.parametrize(
"tensor_shape",
[
(8, 128, 256), # 3D tensor: batch, seq_len, hidden_dim
(4, 64, 128), # smaller 3D tensor
],
)
def test_slice_3d_operation(self, granularity, slice_dim, tensor_shape):
"""Test slicing operations on 3D Float8Tensor across all dimensions"""
config = Float8DynamicActivationFloat8WeightConfig(granularity=granularity)
dtype = torch.bfloat16
device = "cuda"

B, S, H = tensor_shape

# Create a 3D tensor and quantize it (simulating a batched weight tensor)
original_tensor = torch.randn(B, S, H, dtype=dtype, device=device)

# Create Float8Tensor from the 3D high-precision tensor
float8_tensor = Float8Tensor.from_hp(
original_tensor,
granularity=granularity,
mm_config=config.mm_config,
)

slice_size = tensor_shape[slice_dim]
start_idx = 1
end_idx = slice_size - 1

# Perform slicing on the specified dimension
if slice_dim == 0:
sliced_tensor = float8_tensor[start_idx:end_idx, :, :]
expected_qdata = float8_tensor.qdata[start_idx:end_idx, :, :]
expected_scale = float8_tensor.scale[start_idx:end_idx, :]
elif slice_dim == 1:
sliced_tensor = float8_tensor[:, start_idx:end_idx, :]
expected_qdata = float8_tensor.qdata[:, start_idx:end_idx, :]
expected_scale = float8_tensor.scale[:, start_idx:end_idx]
elif slice_dim == 2:
sliced_tensor = float8_tensor[:, :, start_idx:end_idx]
expected_qdata = float8_tensor.qdata[:, :, start_idx:end_idx]
expected_scale = float8_tensor.scale[:, :]

if isinstance(granularity, PerTensor):
# Per-tensor quantization: scale should remain scalar
expected_scale = float8_tensor.scale

# Verify the sliced tensor shape
expected_shape = list(tensor_shape)
expected_shape[slice_dim] = end_idx - start_idx
self.assertEqual(sliced_tensor.shape, torch.Size(expected_shape))

# Verify qdata shape matches
self.assertEqual(sliced_tensor.qdata.shape, torch.Size(expected_shape))
        # review comment (Contributor): and value of qdata as well, similar to:
self.assertEqual(sliced_tensor.qdata, expected_qdata)

# Verify scale shape is correct based on granularity and slice dimension
        # review comment (jerryzh168, Oct 3, 2025): should we verify the value of scale as well
if isinstance(granularity, PerTensor):
# Per-tensor quantization: scale should remain scalar
self.assertEqual(sliced_tensor.scale.numel(), 1)
else:
# Per-row quantization: scale shape depends on which dimension we sliced
if slice_dim == 0:
# Slicing batch dimension affects scale
expected_scale_shape = list(float8_tensor.scale.shape)
expected_scale_shape[0] = end_idx - start_idx
self.assertEqual(
sliced_tensor.scale.shape, torch.Size(expected_scale_shape)
)
elif slice_dim == 1:
# Slicing sequence dimension affects scale
expected_scale_shape = list(float8_tensor.scale.shape)
expected_scale_shape[1] = end_idx - start_idx
self.assertEqual(
sliced_tensor.scale.shape, torch.Size(expected_scale_shape)
)
else:
                # Slicing the hidden dimension (dim=2) does not affect per-row scales,
                # since they are indexed only by the leading (batch, row) dimensions
self.assertEqual(sliced_tensor.scale.shape, float8_tensor.scale.shape)

self.assertEqual(sliced_tensor.scale, expected_scale)

# Verify block_size is correctly updated
self.assertEqual(len(sliced_tensor.block_size), len(expected_shape))
for i in range(len(expected_shape)):
expected_block_dim = min(float8_tensor.block_size[i], expected_shape[i])
self.assertEqual(sliced_tensor.block_size[i], expected_block_dim)

# Test that metadata is preserved
self.assertEqual(sliced_tensor.mm_config, float8_tensor.mm_config)
self.assertEqual(sliced_tensor.act_quant_kwargs, float8_tensor.act_quant_kwargs)
self.assertEqual(
sliced_tensor.kernel_preference, float8_tensor.kernel_preference
)
self.assertEqual(sliced_tensor.dtype, float8_tensor.dtype)

# Test numerical correctness by comparing dequantized results
original_dequantized = float8_tensor.dequantize()
if slice_dim == 0:
sliced_original = original_dequantized[start_idx:end_idx, :, :]
elif slice_dim == 1:
sliced_original = original_dequantized[:, start_idx:end_idx, :]
elif slice_dim == 2:
sliced_original = original_dequantized[:, :, start_idx:end_idx]
sliced_dequantized = sliced_tensor.dequantize()

self.assertEqual(sliced_dequantized, sliced_original)


common_utils.instantiate_parametrized_tests(TestFloat8Tensor)
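
The invariant these slicing tests exercise: for per-row quantization of an (E, N, K)
tensor, qdata keeps the full (E, N, K) shape while the scale is indexed only by the
leading dimensions, so a slice on dim 0 or 1 must be mirrored on the scale and a slice
on dim 2 leaves the scale untouched. A minimal sketch of that relationship using only
plain torch ops (the fp8 round-trip below is a simplified stand-in for illustration,
not torchao's actual quantization kernel):

import torch

# Simulated per-row quantization of a 3D (E, N, K) tensor: one scale per (e, n) row.
E, N, K = 4, 64, 128
hp = torch.randn(E, N, K, dtype=torch.bfloat16)

amax = hp.abs().amax(dim=-1, keepdim=True).float()    # (E, N, 1); keepdim for broadcasting
scale = amax / torch.finfo(torch.float8_e4m3fn).max   # per-row scale
qdata = (hp.float() / scale).to(torch.float8_e4m3fn)  # (E, N, K), same shape as hp

# Slicing dim 0 must slice the scale the same way; slicing dim 2 leaves it untouched.
q0, s0 = qdata[1:3], scale[1:3]     # batch slice: qdata and scale both shrink on dim 0
q2, s2 = qdata[:, :, 8:120], scale  # hidden-dim slice: scale is unchanged

for q, s in [(q0, s0), (q2, s2)]:
    dequant = q.float() * s         # broadcast dequantize
    assert dequant.shape == q.shape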

38 changes: 31 additions & 7 deletions torchao/quantization/quantize_/workflows/float8/float8_tensor.py
@@ -418,10 +418,10 @@ def _(func, types, args, kwargs):

@implements(aten.slice.Tensor)
def _(func, types, args, kwargs):
"""Only supports slicing for dim == 1 and dim == 2
original tensor shape has dimension (N, K)
qdata has dimension (N, K)
scale (per row quantization) has dimension: (N,)
"""Supports slicing for 1d, 2d, and 3d tensors
original tensor shape has dimension (N, K), or (E, N, K)
qdata has dimension (N, K) or (E, N, K)
scale (per row quantization) has dimension: (N,) or (E, N)

since qdata has the same dimension as original tensor, we can directly slice that
for scale, we'll do a slice when dim is 0, and don't need to do anything for dim 1
@@ -431,12 +431,14 @@ def _(func, types, args, kwargs):
"""
self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1])
assert step == 1
-    assert dim == 0 or dim == 1, f"Only dim==0 or 1 are supported, got: {dim}"
+    assert dim == 0 or dim == 1 or dim == 2, (
+        f"Only dim==0,1,2 are supported, got: dim={dim}"
+    )
if end >= self.shape[dim]:
end = self.shape[dim]

-    assert self.qdata.ndim == 2, (
-        f"Expected packed weight to have dim 2, got {self.qdata.dim}"
-    )
+    assert self.qdata.ndim == 2 or self.qdata.ndim == 3, (
+        f"Expected packed weight to have dim==2,3 got: dim={self.qdata.ndim}"
+    )

# Always slice the qdata
@@ -639,6 +641,28 @@ def _(func, types, args, kwargs):
return return_and_correct_aliasing(func, args, kwargs, new_float8_tensor)


@implements(aten.unsqueeze.default)
def _(func, types, args, kwargs):
self, dim = args
assert dim == 0, f"Only dim == 0 is supported, got: {dim}"
qdata = self.qdata.unsqueeze(dim=dim)
scale = self.scale.unsqueeze(dim=dim)
    # Recompute block_size from the unsqueezed shapes:
    # block_size[i] = qdata.shape[i] // scale.shape[i]
    block_size = []
    for i in range(len(qdata.shape)):
        block_size.append(qdata.shape[i] // scale.shape[i])

new = self.__class__(
qdata,
scale,
block_size,
self.mm_config,
self.act_quant_kwargs,
self.kernel_preference,
self.dtype,
)
return return_and_correct_aliasing(func, args, kwargs, new)
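
A concrete instance of the block_size recomputation above, assuming the PerRow shapes
used in the tests (hypothetical sizes, for illustration only): a (256, 128) weight with
per-row scale (256, 1) has block_size [1, 128]; after unsqueeze(0) both shapes gain a
leading 1 and the recomputed block_size is [1, 1, 128]:

qdata_shape = (1, 256, 128)  # (256, 128) weight after unsqueeze(0)
scale_shape = (1, 256, 1)    # per-row scale, also unsqueezed at dim 0
block_size = [q // s for q, s in zip(qdata_shape, scale_shape)]
assert block_size == [1, 1, 128]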


Float8Tensor.__module__ = "torchao.quantization"

# Allow a model with Float8Tensor weights to be loaded with `weights_only=True`
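
Putting both new ops together, a minimal usage sketch (assumes a CUDA device and the
torchao APIs exercised in the tests above; the sizes and the dim-1 slice of an
unsqueezed weight are illustrative, not taken from the tests):

import torch
from torchao.quantization import (
    Float8DynamicActivationFloat8WeightConfig,
    PerRow,
    quantize_,
)

linear = torch.nn.Linear(128, 256, bias=False, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))

w3d = linear.weight.unsqueeze(0)  # (1, 256, 128) Float8Tensor; only dim=0 is supported
sliced = w3d[:, 16:48, :]         # 3D slice on dim 1; the scale is sliced to match
print(sliced.shape, sliced.scale.shape, sliced.block_size)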