[CPU] Add Float8OpaqueTensor for dynamic float8 act float8 weight #2505
Closed
Commits (21):

- 736e1f1 [CPU] Add layout and implementation for dynamic float8 act float8 wei… (Xia-Weiwen)
- 5cc5bcc Merge branch 'main' into float8_da8w8 (Xia-Weiwen)
- c238385 Refine code (Xia-Weiwen)
- 3e7d179 refine comments (Xia-Weiwen)
- 953ac13 Check K % num_groups == 0 (Xia-Weiwen)
- cd53802 Check N & K % 32 == 0; update UT (Xia-Weiwen)
- 0598012 Merge branch 'main' into float8_da8w8 (Xia-Weiwen)
- b4f6520 [CPU] Add float8OpaqueTensor for daf8wf8 (Xia-Weiwen)
- afcee6b Update kernel implementation (Xia-Weiwen)
- 6e2db8e Merge branch 'main' into float8_da8w8 (Xia-Weiwen)
- 01c47ca Fix issues in code (Xia-Weiwen)
- a42d10f Move cpp file (Xia-Weiwen)
- cca4141 Update torchao/quantization/quantize_/workflows/float8/float8_opaque_… (Xia-Weiwen)
- 4c79112 Refine code (Xia-Weiwen)
- f286558 Apply suggestions from code review (Xia-Weiwen)
- 2c58325 Merge branch 'main' into float8_da8w8 (Xia-Weiwen)
- 704e156 Refine kernel code (Xia-Weiwen)
- 9858a42 Merge branch 'main' into float8_da8w8 (Xia-Weiwen)
- 8beeb03 Allocate buffer outside micro gemm kernel (Xia-Weiwen)
- de9a931 Merge branch 'main' into float8_da8w8 (Xia-Weiwen)
- 5e75764 packing_format --> float8_packing_format (Xia-Weiwen)
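For context, a minimal usage sketch of the feature this PR adds, assembled from the test file below (the config fields and the `float8_packing_format="opaque"` value are taken directly from the tests; the layer shape and dtype here are arbitrary):

```python
import torch
from torchao import quantize_
from torchao.quantization import PerRow
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)

# Dynamic float8 activation / float8 weight quantization with the new
# opaque (CPU) packing format; the weight becomes a Float8OpaqueTensor.
quantize_(
    linear,
    Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=PerRow(),
        float8_packing_format="opaque",
    ),
)
```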
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
File changed: `test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py` (259 additions, 0 deletions)
```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD 3-Clause license found in the
# LICENSE file in the root directory of this source tree.

import tempfile
import unittest

import torch
from torch.testing._internal import common_utils
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)

from torchao import quantize_
from torchao.quantization import PerGroup, PerRow, PerTensor
from torchao.quantization.quant_api import (
    Float8DynamicActivationFloat8WeightConfig,
)
from torchao.quantization.utils import compute_error
from torchao.utils import (
    torch_version_at_least,
)


def get_config(granularity):
    return Float8DynamicActivationFloat8WeightConfig(
        activation_dtype=torch.float8_e4m3fn,
        granularity=granularity,
        float8_packing_format="opaque",
    )


class ToyLinearModel(torch.nn.Module):
    def __init__(self, K=64, N=32, bias=False):
        super().__init__()
        self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float)
        self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float)

    def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"):
        return (
            torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device)
            * 0.1,
        )

    def forward(self, x):
        x = self.linear1(x)
        x = self.linear2(x)
        return x


class TestDynamicFloat8Linear(TestCase):
    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs):
        device = "cpu"
        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config(PerRow()),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20
```
```python
    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize(
        "granularity",
        [
            (PerTensor(), PerTensor()),
            (PerTensor(), PerRow()),
            (PerTensor(), PerGroup(64)),
        ],
    )
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 128])
    def test_dynamic_float8_linear_per_tensor_cpu(
        self, granularity, dtype, x_dim, bias, bs
    ):
        device = "cpu"
        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config(granularity),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20
```

> Review comment on `test_dynamic_float8_linear_per_tensor_cpu`: this is per tensor activation? might be good to clarify.
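One reading of that comment, as an assumption inferred from how the tuples are used in this test rather than anything stated in the PR: each granularity pair appears to be `(activation_granularity, weight_granularity)`, so all three parametrized cases scale the activation per tensor. A sketch of what one such case expands to:

```python
import torch
from torchao.quantization import PerGroup, PerTensor
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

# Assumed order: (activation_granularity, weight_granularity).
# Here: per-tensor activation scales, per-group (group_size=64) weight scales.
config = Float8DynamicActivationFloat8WeightConfig(
    activation_dtype=torch.float8_e4m3fn,
    granularity=(PerTensor(), PerGroup(64)),
    float8_packing_format="opaque",
)
```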
```python
    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [4, 128])
    def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias, bs):
        device = "cpu"
        # the shape is not supported by the cpp kernel, so the ref path will be used
        m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config(PerRow()),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    @common_utils.parametrize("group_size", [32, 64, 128])
    def test_dynamic_float8_linear_per_group_cpu(
        self, dtype, x_dim, bias, bs, group_size
    ):
        device = "cpu"
        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config([PerRow(), PerGroup(group_size)]),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20
```

> Review comment on `test_dynamic_float8_linear_per_group_cpu`: why are these tests not combined into the same one? seems all of them are very similar.
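A minimal sketch of the consolidation the reviewer suggests, assuming one body parametrized over the granularity combinations and reusing `get_config`, `ToyLinearModel`, and `compute_error` from the file above (a hypothetical refactor, not part of this PR):

```python
    # Hypothetical combined test: one parametrized body covering the per-row,
    # per-tensor, and per-group cases that currently have separate methods.
    @common_utils.parametrize(
        "granularity",
        [
            PerRow(),
            (PerTensor(), PerTensor()),
            (PerTensor(), PerGroup(64)),
            [PerRow(), PerGroup(64)],
            [PerGroup(64), PerGroup(64)],
        ],
    )
    @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half])
    def test_dynamic_float8_linear_combined_cpu(self, granularity, dtype):
        m = ToyLinearModel(256, 256).eval().to(dtype)
        example_inputs = m.example_inputs(batch_size=16, dtype=dtype)
        y = m(*example_inputs)
        with torch.no_grad():
            quantize_(m, get_config(granularity))
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
```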
```python
    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+")
    @common_utils.parametrize("dtype", [torch.bfloat16, torch.half])
    @common_utils.parametrize("x_dim", [2, 3])
    @common_utils.parametrize("bias", [True, False])
    @common_utils.parametrize("bs", [1, 160])
    @common_utils.parametrize("group_size", [32, 64, 128])
    def test_dynamic_float8_linear_per_group_act_cpu(
        self, dtype, x_dim, bias, bs, group_size
    ):
        device = "cpu"
        m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device)
        example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device)
        if x_dim == 3:
            example_inputs = (example_inputs[0].unsqueeze(0),)
        y = m(*example_inputs)

        with torch.no_grad():
            quantize_(
                m,
                get_config([PerGroup(group_size), PerGroup(group_size)]),
            )
            y1 = m(*example_inputs)
            assert compute_error(y, y1) > 20
            y2, code = torch._inductor.utils.run_and_get_code(
                torch.compile(m, fullgraph=True, dynamic=True),
                *example_inputs,
            )
            # ensure the expected op is in the code
            assert "torch.ops.torchao.float8_linear_cpu.default" in code[0]
            assert compute_error(y, y2) > 20

    @unittest.skipIf(
        "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"),
        reason="cpp kernels not built",
    )
    @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16])
    def test_module_path(self, dtype):
        linear = torch.nn.Linear(128, 256, dtype=dtype)
        quantize_(linear, get_config(PerRow()))
        self.assertEqual(
            str(type(linear.weight)),
            "<class 'torchao.quantization.Float8OpaqueTensor'>",
        )

        with tempfile.NamedTemporaryFile() as f:
            torch.save(linear.state_dict(), f)
            f.seek(0)
            state_dict = torch.load(f)
            self.assertEqual(
                str(type(state_dict["weight"])),
                "<class 'torchao.quantization.Float8OpaqueTensor'>",
            )


common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear)


if __name__ == "__main__":
    run_tests()
```
> Review comment: nit: rename the test class to `TestFloat8OpaqueTensor`?