diff --git a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py new file mode 100644 index 0000000000..99acd5ed82 --- /dev/null +++ b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao import quantize_ +from torchao.quantization import PerGroup, PerRow, PerTensor +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, +) +from torchao.quantization.utils import compute_error +from torchao.utils import ( + torch_version_at_least, +) + + +def get_config(granularity): + return Float8DynamicActivationFloat8WeightConfig( + activation_dtype=torch.float8_e4m3fn, + granularity=granularity, + float8_packing_format="opaque", + ) + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, K=64, N=32, bias=False): + super().__init__() + self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float) + self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float) + + def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): + return ( + torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device) + * 0.1, + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestDynamicFloat8Linear(TestCase): + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config(PerRow()), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize( + "granularity", + [ + (PerTensor(), PerTensor()), + (PerTensor(), PerRow()), + (PerTensor(), PerGroup(64)), + ], + ) + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 128]) + def test_dynamic_float8_linear_per_tensor_cpu( + self, granularity, dtype, x_dim, bias, bs + ): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config(granularity), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [4, 128]) + def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias, bs): + device = "cpu" + # the shape is not supported by cpp kernel, so the ref path will be used. + m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config(PerRow()), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + @common_utils.parametrize("group_size", [32, 64, 128]) + def test_dynamic_float8_linear_per_group_cpu( + self, dtype, x_dim, bias, bs, group_size + ): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config([PerRow(), PerGroup(group_size)]), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + @common_utils.parametrize("group_size", [32, 64, 128]) + def test_dynamic_float8_linear_per_group_act_cpu( + self, dtype, x_dim, bias, bs, group_size + ): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config([PerGroup(group_size), PerGroup(group_size)]), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + def test_module_path(self, dtype): + linear = torch.nn.Linear(128, 256, dtype=dtype) + quantize_(linear, get_config(PerRow())) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + +common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp new file mode 100644 index 0000000000..0301568477 --- /dev/null +++ b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp @@ -0,0 +1,603 @@ +#include +#include +#include +#include +#include "utils.h" +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace torchao { + +namespace { + +#define BLOCK_N 32 + +#define PER_TENSOR 1 +#define PER_ROW 2 +#define PER_GROUP 3 + +static bool cpublas_checked = false; +static bool cpublas_can_pack = false; + +bool cpublas_could_pack() { + // the could_pack check requires AMX support implicitly + if (cpublas_checked) { + return cpublas_can_pack; + } +#ifdef CPUBLAS_BRGEMM_F8F8F32 + cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn); +#else + cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16); +#endif + cpublas_checked = true; + return cpublas_can_pack; +} + +/* +return: packed_weight, packed_scales +*/ +std::tuple +float8_linear_prepack_impl( + const at::Tensor& weight, + const at::Tensor& scales) { + // weight shape = [N, K] + // scales shape = [N, G] + TORCH_CHECK(weight.dim() == 2, + "Float8 linear CPU: Weight should be a 2D tensor for packing"); + int N = weight.size(0); + int K = weight.size(1); + constexpr int block_n = BLOCK_N; + // Case to fall back + if (N % block_n != 0 || K % 32 != 0) { + return std::make_tuple(weight, scales); + } + + auto new_scales = scales; + bool is_per_tensor = new_scales.numel() == 1; + if (new_scales.dim() == 1 && !is_per_tensor) { + new_scales.unsqueeze_(1); + } + new_scales = new_scales.to(at::kFloat); + int G = is_per_tensor ? 1 : new_scales.size(1); + TORCH_CHECK(K % G == 0, "K should be divisible by num_groups"); + int group_size = K / G; + int block_k = group_size > 128 ? 128 : group_size; + while (K % block_k != 0) { + block_k /= 2; + } + TORCH_CHECK(block_k > 0 && block_k <= group_size, + "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]"); + int Nc = N / block_n; + int Kc = K / block_k; + + // Reorder weight to [N/block_n, K/block_k, block_k, block_n] + // Reorder scales to [N/block_n, G, block_n] + auto weight_view = weight.view({Nc, block_n, Kc, block_k}); + at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); + at::Tensor blocked_weight; + at::Tensor blocked_scales = is_per_tensor ? new_scales.view({1}) : new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + +#if defined(CPU_CAPABILITY_AVX512) + if (cpublas_could_pack()) { +#ifdef CPUBLAS_BRGEMM_F8F8F32 + constexpr int vnni_size = get_vnni_size(); // for fp8 +#else + constexpr int vnni_size = get_vnni_size(); // for bfloat16 +#endif + blocked_weight = at::empty({Nc, Kc, block_k, block_n}, weight.options()); + auto weight_ptr = reinterpret_cast(weight_reordered.data_ptr()); + auto blocked_weight_ptr = reinterpret_cast(blocked_weight.data_ptr()); + int64_t num_blocks = Nc * Kc; + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + auto in_ptr = weight_ptr + i * block_k * block_n; + auto out_ptr = blocked_weight_ptr + i * block_k * block_n; + + // Reorder weight block to VNNI + // plain shape = [block_k, block_n] + // packed shape = [block_k / VNNI_SIZE, block_n, VNNI_SIZE] viewed as [block_k, block_n] + constexpr int n_group_size = 8; + constexpr int n_group = block_n / n_group_size; // 4 + for (int nb = 0; nb < n_group; ++nb) { + for (int k = 0; k < block_k; k += vnni_size) { + for (int ni = 0; ni < n_group_size; ++ni) { + for (int ki = 0; ki < vnni_size; ++ki) { + int src_idx = nb * n_group_size + ni + (k + ki) * block_n; + int dst_idx = (nb * n_group_size + ni) * vnni_size + k * block_n + ki; + *(out_ptr + dst_idx) = *(in_ptr + src_idx); + } + } + } + } + } + }); + } else +#endif + { + blocked_weight = weight_reordered; + } + + return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales)); +} + +#if defined(CPU_CAPABILITY_AVX512) +// this doesn't handle NaN. +inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); + + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); + + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); +} + +static void cvt_f8e4m3_to_bf16( + const at::Float8_e4m3fn* __restrict__ in, + at::BFloat16* out, + int64_t rows, + int64_t cols, + int64_t stride) { + if (stride == cols) { + // A contiguous buffer + size_t len = rows * cols; + size_t i = 0; + for (; i < len; i += 32) { + __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[i]); + __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); + _mm512_storeu_si512((__m512i*)(out + i), (__m512i)bf16_vec); + } + for (; i < len; ++i) { + out[i] = (at::BFloat16)in[i]; + } + } else { + // Non-contiguous. Access each row with stride + TORCH_CHECK(stride > cols); + for (int r = 0; r < rows; ++r) { + size_t i = 0; + size_t vec_len = cols / 32 * 32; + for (; i < vec_len; i += 32) { + __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]); + __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); + _mm512_storeu_si512((__m512i*)(out + r * cols + i), (__m512i)bf16_vec); + } + for (; i < cols; ++i) { + out[r * cols + i] = (at::BFloat16)in[r * stride + i]; + } + } + } +} + + +// accumulate and store result to buffer +// if act/wei are per_group quantized, apply scales +template +static void _accumulate_result( + float* __restrict__ output, + const float* __restrict__ input, + const float* __restrict__ scale_a, + const float* __restrict__ scale_b, + int M, + int ldi, + int ldo, + int ldsa = 1) { + float a_scale, b_scale; + __m512 va_scale; + __m512 vb_scale; + for (int m = 0; m < M; ++m) { + if constexpr (act_quant_mode == PER_GROUP) { + a_scale = *(scale_a + m * ldsa); + va_scale = _mm512_set1_ps(a_scale); + } + constexpr int N_UNROLL = N / 16; + c10::ForcedUnroll{}([&](auto i) { + constexpr int n = i * 16; + __m512 vc_f = _mm512_loadu_ps(input + m * ldi + n); + if constexpr (act_quant_mode == PER_GROUP) { + vc_f = _mm512_mul_ps(vc_f, va_scale); + } + if constexpr (wei_quant_mode == PER_GROUP) { + vb_scale = _mm512_loadu_ps(scale_b + n); + vc_f = _mm512_mul_ps(vc_f, vb_scale); + } + if constexpr (accum) { + __m512 vo = _mm512_loadu_ps(output + m * ldo + n); + _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f)); + } else { + _mm512_storeu_ps(output + m * ldo + n, vc_f); + } + }); + constexpr int tail_start = N / 16 * 16; + for (int n = tail_start; n < N; ++n) { + float dq_val = input[m * ldi + n]; + if constexpr (act_quant_mode == PER_GROUP) { + dq_val = dq_val * a_scale; + } + if constexpr (wei_quant_mode == PER_GROUP) { + b_scale = scale_b[n]; + dq_val = dq_val * b_scale; + } + if constexpr (accum) { + output[m * ldo + n] += dq_val; + } else { + output[m * ldo + n] = dq_val; + } + } + } +} + +// Store result to output buffer with dtype conversion +// If act/wei are per_row or per_tensor quantized, apply scales +// If bias is not null, add bias +template +inline void store_out( + const float* y_buf, + out_dtype* c_ptr, + int64_t M, + int64_t lda, + const float* scales_a, + const float* scales_b, + const float* bias) { + float a_scale = 1.0, b_scale = 1.0; + __m512 va_scale, vb_scale; + if constexpr (act_quant_mode == PER_TENSOR) { + a_scale = *scales_a; + } + if constexpr (wei_quant_mode == PER_TENSOR) { + b_scale = *scales_b; + vb_scale = _mm512_set1_ps(b_scale); + } + for (int i = 0; i < M; ++i) { + if constexpr (act_quant_mode == PER_ROW) { + a_scale = *(scales_a + i); + } + if constexpr (act_quant_mode != PER_GROUP) { + va_scale = _mm512_set1_ps(a_scale); + } + constexpr int N_UNROLL = N / 16; + c10::ForcedUnroll{}([&](auto idx) { + constexpr int j = idx * 16; + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m512 bias_vec = bias ? _mm512_loadu_ps(bias + j) : _mm512_setzero_ps(); + if constexpr (act_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, va_scale); + } + if constexpr (wei_quant_mode == PER_ROW) { + vb_scale = _mm512_loadu_ps(scales_b + j); + } + if constexpr (wei_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, vb_scale); + } + y_vec = _mm512_add_ps(y_vec, bias_vec); + if constexpr (std::is_same::value) { + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } else if constexpr (std::is_same::value) { + __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); + } else if constexpr (std::is_same::value) { + __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } else { + TORCH_CHECK(false, "Unsupported output dtype"); + } + }); + constexpr int tail_start = N / 16 * 16; + for (int j = tail_start; j < N; ++j) { + if constexpr (wei_quant_mode == PER_ROW) { + b_scale = scales_b[j]; + } + c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); + } + } // for M +} + +#else // no AVX512 + +static void cvt_f8e4m3_to_bf16( + const at::Float8_e4m3fn* __restrict__ in, + at::BFloat16* out, + int64_t rows, + int64_t cols, + int64_t stride) { + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + out[r * cols + c] = (at::BFloat16)in[r * stride + c]; + } + } +} + +// Store result to output buffer with dtype conversion +// If act/wei are per_row or per_tensor quantized, apply scales +// If bias is not null, add bias +template +inline void store_out( + const float* y_buf, + out_dtype* c_ptr, + int64_t M, + int64_t lda, + const float* scales_a, + const float* scales_b, + const float* bias) { + float a_scale = 1.0, b_scale = 1.0; + if constexpr (act_quant_mode == PER_TENSOR) { + a_scale = *scales_a; + } + if constexpr (wei_quant_mode == PER_TENSOR) { + b_scale = *scales_b; + } + for (int i = 0; i < M; ++i) { + if constexpr (act_quant_mode == PER_ROW) { + a_scale = *(scales_a + i); + } + for (int j = 0; j < N; ++j) { + if constexpr (wei_quant_mode == PER_ROW) { + b_scale = scales_b[j]; + } + c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); + } + } // for M +} + +#endif // CPU_CAPABILITY_AVX512 + +template +void _micro_gemm( + float* C, + const at::Float8_e4m3fn* A, + const float* scales_a, + const at::Float8_e4m3fn* B, + const float* scales_b, + int64_t M, + int64_t K, + int64_t lda, + int64_t ldc, + int64_t ldsa, + float* ukernel_buf, + at::BFloat16* dqA_buf, + at::BFloat16* dqB_buf) { + // If FP8 brgemm is not available, convert A/B to bf16 for computation + // Compute GEMM fp8 * fp8 -> fp32 (or bf16 * bf16 -> fp32) + // If per_group quant, apply scales. Otherwise, don't apply scales here + // Finally accumulate and store results +#if defined(CPU_CAPABILITY_AVX512) + if constexpr (cpublas_can_pack) { +#ifdef CPUBLAS_BRGEMM_F8F8F32 + at::native::cpublas::brgemm( + M, + N, + K, + lda /*lda*/, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + A, + B, + ukernel_buf, + true /* is_vnni */); +#else + cvt_f8e4m3_to_bf16(A, dqA_buf, M, K, lda); + cvt_f8e4m3_to_bf16(B, dqB_buf, K, N, N); + at::native::cpublas::brgemm( + M, + N, + K, + K /*lda*/, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + dqA_buf, + dqB_buf, + ukernel_buf, + true /* is_vnni */); +#endif + _mm_prefetch(B + N * (K + 128), _MM_HINT_T0); + _mm_prefetch(A + K + 128, _MM_HINT_T0); + _accumulate_result( + C, + ukernel_buf, + scales_a, + scales_b, + M, + N /*ldi*/, + ldc, + ldsa); + } else +#endif + { + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0; + for (int64_t k = 0; k < K; ++k) { + sum += ((float)A[i * lda + k] * (float)B[k * N + j]); + } + if constexpr (act_quant_mode == PER_GROUP) { + sum *= scales_a[i * ldsa]; + } + if constexpr (wei_quant_mode == PER_GROUP) { + sum *= scales_b[j]; + } + C[i * ldc + j] += sum; + } + } + } +} + +template +void _float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::Tensor& output) { + // input shape = [..., K] + // input is per token quantized + int64_t K = input.size(-1); + auto input_view = input.view({-1, K}); + int64_t M = input_view.size(0); + + // weight shape = [Nc, Kc, block_k, block_n] + // scales shape = [Nc, G, block_n] + int64_t Nc = weight.size(0); + int64_t Kc = weight.size(1); + int64_t block_k = weight.size(2); + constexpr int64_t block_n = BLOCK_N; + TORCH_CHECK(weight.size(3) == block_n, "Float8 linear: unexpected weight shape"); + int64_t N = Nc * block_n; + TORCH_CHECK(K == Kc * block_k, "Float8 linear: weight and input shapes mismatch"); + auto [parallel_on_M, block_m, Mc, Mc_parallel] = get_m_blocking(M); + int64_t num_parallel_blocks = Mc_parallel * Nc; + + // scales shape = [Nc, G, block_n] + int64_t num_groups = wei_quant_mode == PER_TENSOR ? 1 : weight_scales.size(1); + TORCH_CHECK(K % num_groups == 0, "K should be divisible by num_groups"); + int64_t group_size = K / num_groups; + TORCH_CHECK(group_size % block_k == 0, + "Float8 linear: group_size should be divisible by block_k"); + int64_t block_per_group = group_size / block_k; + TORCH_CHECK(input_scales.numel() == 1 || input_scales.numel() == M || input_scales.numel() == M * num_groups, "Float8 linear: unexpected input scales shape"); + auto ldsa = act_quant_mode == PER_TENSOR ? 0 : act_quant_mode == PER_ROW ? 1 : num_groups; + + const at::Float8_e4m3fn* a_ptr = input_view.data_ptr(); + const float* a_scales_ptr = input_scales.data_ptr(); + const at::Float8_e4m3fn* b_ptr = weight.data_ptr(); + const float* b_scales_ptr = weight_scales.data_ptr(); + out_dtype* c_ptr = output.data_ptr(); + const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; + + int64_t block_size = block_m * block_n; + int64_t num_thread = at::get_num_threads(); + at::Tensor y_buffer = at::empty({num_thread, block_size}, output.options().dtype(at::kFloat)); + // Create buffer for brgemm output and dqA/dqB (optional) +#if defined(CPU_CAPABILITY_AVX512) + // buffer for brgemm output in float32 + int64_t buffer_size = block_size * 2; // float32 = bfloat16 * 2 +#ifndef CPUBLAS_BRGEMM_F8F8F32 + // buffers for dqA & dqB in bf16 + buffer_size += (block_k * block_n + block_m * block_k); +#endif + at::Tensor micro_gemm_buffer = at::empty({num_thread, buffer_size}, output.options().dtype(at::kBFloat16)); +#endif + + at::parallel_for(0, num_parallel_blocks, 1, [&](int64_t begin, int64_t end) { + // Get the address of pre-allocated buffers + float* y_buf = y_buffer.data_ptr() + at::get_thread_num() * block_size; + at::BFloat16 *dqA_buffer = nullptr, *dqB_buffer = nullptr; + float* ukernel_buf = nullptr; +#if defined(CPU_CAPABILITY_AVX512) + at::BFloat16* micro_gemm_buf = micro_gemm_buffer.data_ptr() + at::get_thread_num() * buffer_size; + ukernel_buf = reinterpret_cast(micro_gemm_buf); +#ifndef CPUBLAS_BRGEMM_F8F8F32 + dqA_buffer = micro_gemm_buf; + dqB_buffer = micro_gemm_buf + block_m * block_k; + ukernel_buf = reinterpret_cast(micro_gemm_buf + block_m * block_k + block_k * block_n); +#endif +#endif + int64_t mc = 0, nc = 0; + at::native::data_index_init(begin, mc, Mc_parallel, nc, Nc); + for (const auto i : c10::irange(begin, end)) { + (void)i; // Suppress unused variable + int64_t mc_end = parallel_on_M ? mc + 1 : Mc; + + for (int mci = mc; mci < mc_end; ++mci) { + int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; + zero_buffer(y_buf, m_size * block_n); + for (int kci = 0; kci < Kc; ++kci) { + auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group; + auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n; + _micro_gemm( + /* C */ y_buf, + /* A */ a_ptr + mci * block_m * K + kci * block_k, + /* scales_a */ scales_a, + /* B */ b_ptr + (nc * Kc + kci) * block_n * block_k, + /* scales_b */ scales_b, + /* M */ m_size, + /* K */ block_k, + /* lda */ K, + /* ldc */ block_n, + /* ldsa */ ldsa, + /* ukernel_buf */ ukernel_buf, + /* dqA_buf */ dqA_buffer, + /* dqB_buf */ dqB_buffer); + } + // store y_buf to output with dtype conversion + auto scales_a = act_quant_mode == PER_TENSOR ? a_scales_ptr : + act_quant_mode == PER_ROW ? a_scales_ptr + mci * block_m : nullptr; + auto scales_b = wei_quant_mode == PER_TENSOR ? b_scales_ptr : + wei_quant_mode == PER_ROW ? b_scales_ptr + nc * block_n : nullptr; + auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; + store_out( + y_buf, + c_ptr + mci * block_m * N + nc * block_n, + m_size, + N /*lda*/, + scales_a, + scales_b, + bias_data); + } + at::native::data_index_step(mc, Mc_parallel, nc, Nc); + } + if constexpr (cpublas_can_pack) { + at::native::cpublas::brgemm_release(); + } + }); +} + +at::Tensor float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::ScalarType output_dtype) { + int64_t N = weight.dim() == 4 ? weight.size(0) * weight.size(-1) : weight.size(0); + int act_quant_mode = input_scales.numel() == 1 ? PER_TENSOR : + input_scales.numel() == input.numel() / input.size(-1) ? PER_ROW : + PER_GROUP; + int wei_quant_mode = weight_scales.numel() == 1 ? PER_TENSOR : + weight_scales.numel() == N ? PER_ROW : + PER_GROUP; + // Case to fall back + if (weight.dim() == 2) { + TORCH_CHECK(act_quant_mode != PER_GROUP && wei_quant_mode != PER_GROUP, + "FP8 linear: Per-group quantization is not supported in the fallback path"); + auto y_fp32 = at::linear( + input.to(at::kFloat).mul_(input_scales), + weight.to(at::kFloat).mul_(weight_scales), + bias); + return y_fp32.to(output_dtype); + } + + static bool cpublas_can_pack = cpublas_could_pack(); + auto out_sizes = input.sizes().vec(); + out_sizes.back() = N; + auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); + + AT_DISPATCH_LINEAR_KERNEL(output_dtype, cpublas_can_pack, act_quant_mode, wei_quant_mode, [&](){ + _float8_linear_impl( + input, + input_scales, + weight, + weight_scales, + bias, + output); + }); + return output; +} + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::float8_linear_prepack_cpu", &float8_linear_prepack_impl); + m.impl("torchao::float8_linear_cpu", &float8_linear_impl); +} + +} // namespace torchao diff --git a/torchao/csrc/cpu/aten_kernels/utils.h b/torchao/csrc/cpu/aten_kernels/utils.h new file mode 100644 index 0000000000..0a238a316e --- /dev/null +++ b/torchao/csrc/cpu/aten_kernels/utils.h @@ -0,0 +1,111 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +int64_t get_m_block(int64_t M) { + if (M <= 48) { + return M; + } else if (M < 64) { + return 32; + } else if (M < 96) { + return 64; + } else { + return 128; + } +} + +std::tuple +get_m_blocking(int64_t M) { + bool parallel_on_M = M > 128; + int64_t block_m = get_m_block(M); + int64_t Mc = (M + block_m - 1) / block_m; + int64_t Mc_parallel = parallel_on_M ? Mc : 1; + return std::make_tuple(parallel_on_M, block_m, Mc, Mc_parallel); +} + +#if defined(CPU_CAPABILITY_AVX512) +template +void zero_buffer(T* data, int64_t size) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto zero_vec = at::vec::Vectorized(0); + int64_t d = 0; + for (; d < size - (size % vec_size); d += vec_size) { + zero_vec.store(data + d); + } + if (d < size) { + zero_vec.store(data + d, size - d); + } +} +#else +template +void zero_buffer(T* data, int64_t size) { + memset(data, 0, sizeof(T) * size); +} +#endif + +template struct vnni_traits; +template <> struct vnni_traits { static constexpr int size = 2; }; +template <> struct vnni_traits { static constexpr int size = 2; }; +template <> struct vnni_traits { static constexpr int size = 4; }; +template <> struct vnni_traits { static constexpr int size = 4; }; + +template constexpr int get_vnni_size() { return vnni_traits::size; } + + +// Utilities for dispatch +#define AT_DISPATCH_OUT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, out_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, out_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, out_t, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ENUM(VALUE, TYPE, HINT, ...) \ + case VALUE: { \ + constexpr TYPE HINT = VALUE; \ + __VA_ARGS__; \ + break; \ + } + +#define AT_DISPATCH_BOOL(VALUE, NAME, HINT, ...) \ + [&]() { \ + switch (VALUE) { \ + AT_DISPATCH_CASE_ENUM(true, bool, HINT, __VA_ARGS__) \ + AT_DISPATCH_CASE_ENUM(false, bool, HINT, __VA_ARGS__) \ + } \ + }() + +#define AT_DISPATCH_QUANT_MODE(MODE, NAME, HINT, ...) \ + [&]() { \ + switch (MODE) { \ + AT_DISPATCH_CASE_ENUM(PER_TENSOR, int, HINT, __VA_ARGS__) \ + AT_DISPATCH_CASE_ENUM(PER_ROW, int, HINT, __VA_ARGS__) \ + AT_DISPATCH_CASE_ENUM(PER_GROUP, int, HINT, __VA_ARGS__) \ + } \ + }() + +#define AT_DISPATCH_LINEAR_KERNEL(OUT_DTYPE, CAN_PACK, A_QUANT_MODE, B_QUANT_MODE, ...) \ + AT_DISPATCH_BOOL( \ + CAN_PACK, "cpublas_can_pack", can_pack, \ + AT_DISPATCH_QUANT_MODE( \ + A_QUANT_MODE, "act_quant_mode", a_quant_mode, \ + AT_DISPATCH_QUANT_MODE( \ + B_QUANT_MODE, "wei_quant_mode", b_quant_mode, \ + AT_DISPATCH_OUT_TYPES( \ + OUT_DTYPE, "out_dtype", __VA_ARGS__ \ + ) \ + ) \ + ) \ + ) diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index f15d38576c..4fa69ecd78 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -14,6 +14,7 @@ from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul from torchao.float8.types import FP8Granularity from torchao.quantization.granularity import ( + PerGroup, PerRow, PerTensor, ) @@ -205,22 +206,19 @@ def _normalize_granularity( ] ], ) -> Tuple[FP8Granularity, FP8Granularity]: + supported_granularities = (PerTensor, PerRow, PerGroup) processed_granularity = None if granularity is None: processed_granularity = (PerTensor(), PerTensor()) - elif isinstance(granularity, (PerTensor, PerRow)): + elif isinstance(granularity, supported_granularities): processed_granularity = (granularity, granularity) elif isinstance(granularity, (tuple, list)) and len(granularity) == 2: if not ( - isinstance(granularity[0], (PerTensor, PerRow)) - and isinstance(granularity[1], (PerTensor, PerRow)) + isinstance(granularity[0], supported_granularities) + and isinstance(granularity[1], supported_granularities) ): raise ValueError( - f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported." - ) - if not isinstance(granularity[0], type(granularity[1])): - raise ValueError( - f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." + f"Invalid granularity types: {granularity}, only PerTensor or PerRow or PerGroup are supported." ) processed_granularity = tuple(granularity) else: @@ -232,6 +230,7 @@ def _normalize_granularity( def _check_hardware_support( granularities: Tuple[FP8Granularity, FP8Granularity], + device: str = "cuda", ) -> None: """ Validate that the hardware supports the requested granularities. @@ -243,12 +242,16 @@ def _check_hardware_support( AssertionError: If hardware doesn't support the requested granularity ValueError: If invalid granularity type is provided """ + supported_granularities = ( + (PerTensor, PerRow, PerGroup) if device == "cpu" else (PerTensor, PerRow) + ) for _granularity in granularities: - if not isinstance(_granularity, (PerTensor, PerRow)): + if not isinstance(_granularity, supported_granularities): raise ValueError( - f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported." + f"Invalid granularity type: {_granularity}, only {supported_granularities} are supported." ) + if device != "cpu": assert is_sm_at_least_89() or is_MI300(), ( "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+." ) diff --git a/torchao/float8/types.py b/torchao/float8/types.py index b332a9629a..63cabc9582 100644 --- a/torchao/float8/types.py +++ b/torchao/float8/types.py @@ -12,8 +12,8 @@ from typing import TYPE_CHECKING, Union if TYPE_CHECKING: - from torchao.quantization.granularity import PerRow, PerTensor + from torchao.quantization.granularity import PerGroup, PerRow, PerTensor # Define FP8Granularity type alias to break circular import dependencies -FP8Granularity = Union["PerTensor", "PerRow"] +FP8Granularity = Union["PerTensor", "PerRow", "PerGroup"] diff --git a/torchao/ops.py b/torchao/ops.py index b6348f90a5..77475d8185 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -71,6 +71,12 @@ lib.define( "_scaled_embedding_bag(Tensor qweight, Tensor indices, Tensor offsets, Tensor weight_scale, float o_scale, int mode, bool include_last_offset) -> Tensor" ) +lib.define( + "float8_linear_prepack_cpu(Tensor weight, Tensor scales) -> (Tensor, Tensor)" +) +lib.define( + "float8_linear_cpu(Tensor input, Tensor input_scales, Tensor weight, Tensor weight_scales, Tensor? bias, ScalarType output_dtype) -> Tensor" +) def register_custom_op(name): @@ -1103,6 +1109,21 @@ def _( return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) +def float8_linear_prepack_cpu( + weight: Tensor, + scales: Tensor, +) -> Tensor: + """ + Prepack weights for float8 linear operator on CPU. + Args: + weight: weight tensor. + scales: scales for weight tensor. + Returns: + packed weight, packed scales + """ + return torch.ops.torchao.float8_linear_prepack_cpu.default(weight, scales) + + @register_custom_op("torchao::_scaled_embedding_bag") def _( qweight: Tensor, @@ -1117,3 +1138,52 @@ def _( assert include_last_offset == True batch_size = offsets.shape[0] - 1 return qweight.new_empty(batch_size, qweight.shape[1], dtype=qweight.dtype) + + +@register_custom_op("torchao::float8_linear_prepack_cpu") +def _(weight: Tensor, scales: Tensor) -> Tensor: + return weight, scales + + +def float8_linear_cpu( + input: Tensor, + input_scales: Tensor, + weight: Tensor, + weight_scales: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +): + """ + float8 linear operator on CPU. + Args: + input: input tensor. + input_scales: scales for input tensor. + weight: weight tensor. + weight_scales: scales for weight tensor. + bias: optional bias tensor. + out_dtype: output data type. + Returns: + output tensor in out_dtype. + """ + return torch.ops.torchao.float8_linear_cpu.default( + input, + input_scales, + weight, + weight_scales, + bias, + out_dtype, + ) + + +@register_custom_op("torchao::float8_linear_cpu") +def _( + input: Tensor, + input_scales: Tensor, + weight: Tensor, + weight_scales: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +) -> Tensor: + assert weight.dim() in (2, 4) + N = weight.size(0) * weight.size(3) if weight.dim() == 4 else weight.size(0) + return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index b32868b684..f898afc2b7 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -88,6 +88,7 @@ quantize_affine, ) from .quantize_.workflows import ( + Float8OpaqueTensor, Float8Tensor, Int4MarlinSparseTensor, Int4OpaqueTensor, @@ -170,6 +171,7 @@ "Int4TilePackedTo4dTensor", "Float8Tensor", "Int4OpaqueTensor", + "Float8OpaqueTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 6d928a4477..bff189df18 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -15,8 +15,10 @@ from .granularity import ( Granularity, PerAxis, + PerGroup, PerRow, PerTensor, + PerToken, ) from .quant_primitives import ( MappingType, @@ -78,8 +80,13 @@ def get_block_size( block_size = list(input_shape) block_size[granularity.axis] = 1 return tuple(block_size) - elif isinstance(granularity, PerRow): + elif isinstance(granularity, (PerRow, PerToken)): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert input_shape[-1] % granularity.group_size == 0, ( + f"Group size {granularity.group_size} does not divide input shape {input_shape}" + ) + return (1,) * (len(input_shape) - 1) + (granularity.group_size,) raise ValueError(f"Unsupported Granularity: {granularity}") diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py index ebc9864f3d..32cc31b21f 100644 --- a/torchao/quantization/qat/fake_quantize_config.py +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -422,6 +422,11 @@ def _infer_fake_quantize_configs( (act_granularity, weight_granularity) = _normalize_granularity( base_config.granularity ) + assert act_granularity == weight_granularity and isinstance( + act_granularity, (PerTensor, PerRow) + ), ( + "Currently only support same granularity for both activations and weights, and only PerTensor or PerRow" + ) act_config = Float8FakeQuantizeConfig( dtype=base_config.activation_dtype, granularity=act_granularity, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3a6ecc08a7..e57fa56f1d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -69,6 +69,8 @@ KernelPreference, ) from torchao.quantization.quantize_.workflows import ( + Float8OpaqueTensor, + Float8PackingFormat, Float8Tensor, Int4ChooseQParamsAlgorithm, Int4MarlinSparseTensor, @@ -92,6 +94,7 @@ ) from torchao.utils import ( _ConfigDeprecationWrapper, + check_cpu_version, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -1689,6 +1692,26 @@ def _input_activation_quant_func_fp8( return activation +def _input_activation_quant_cpu_fp8( + x: torch.Tensor, + activation_granularity: FP8Granularity, + activation_dtype: torch.dtype, +): + """Dynamic quantize activation to fp8 for CPU.""" + if not isinstance(activation_granularity, PerGroup): + block_size = get_block_size(x.shape, activation_granularity) + else: + group_size = activation_granularity.group_size + block_size = (*([1] * (len(x.shape) - 1)), group_size) + return to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=PlainLayout(), + ) + + def _fp8_mm_compat(weight: torch.Tensor) -> bool: """ Check if a weight tensor meets float8 quantization requirements. @@ -1747,6 +1770,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): kernel_preference: KernelPreference = KernelPreference.AUTO set_inductor_config: bool = True version: int = 2 + float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN def __post_init__(self): torch._C._log_api_usage_once( @@ -1757,6 +1781,17 @@ def __post_init__(self): activation_granularity, weight_granularity = _normalize_granularity( self.granularity ) + if self.float8_packing_format == Float8PackingFormat.PLAIN: + assert isinstance(activation_granularity, (PerTensor, PerRow)), ( + f"Unsupported granularity {activation_granularity}, only PerTensor or PerRow are supported." + ) + assert isinstance(weight_granularity, (PerTensor, PerRow)), ( + f"Unsupported granularity {weight_granularity}, only PerTensor or PerRow are supported." + ) + if not isinstance(activation_granularity, type(weight_granularity)): + raise ValueError( + f"Different granularities for activation and weight are not supported: {activation_granularity, weight_granularity}" + ) self.granularity = [activation_granularity, weight_granularity] @@ -1774,17 +1809,19 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_value_lb = config.activation_value_lb activation_value_ub = config.activation_value_ub kernel_preference = config.kernel_preference + float8_packing_format = config.float8_packing_format # Ensure works on device - _check_hardware_support(granularity) activation_granularity, weight_granularity = granularity + is_cpu = weight.device.type == "cpu" + _check_hardware_support(granularity, weight.device.type) - if not _fp8_mm_compat(weight): + if not is_cpu and not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked return weight - if isinstance(weight_granularity, PerRow): + if not is_cpu and isinstance(weight_granularity, PerRow): assert weight.dtype == torch.bfloat16, ( "PerRow quantization only works for bfloat16 precision input weight" ) @@ -1824,14 +1861,26 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): kernel_preference=kernel_preference, ) - quantized_weight = Float8Tensor.from_hp( - weight, - float8_dtype=weight_dtype, - granularity=weight_granularity, - mm_config=mm_config, - kernel_preference=kernel_preference, - act_quant_kwargs=act_quant_kwargs, - ) + if float8_packing_format == Float8PackingFormat.PLAIN: + quantized_weight = Float8Tensor.from_hp( + weight, + float8_dtype=weight_dtype, + granularity=weight_granularity, + mm_config=mm_config, + kernel_preference=kernel_preference, + act_quant_kwargs=act_quant_kwargs, + ) + elif float8_packing_format == Float8PackingFormat.OPAQUE: + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = Float8OpaqueTensor.from_hp( + weight, + block_size=block_size, + act_quant_kwargs=act_quant_kwargs, + ) + else: + raise ValueError( + f"Unsupported float8 packing format: {float8_packing_format}" + ) return quantized_weight @@ -1840,16 +1889,21 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): def _float8_dynamic_activation_float8_weight_transform( module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig ): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - ) - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, "weight"), ( "applying float8 dynamic activation quant requires module to have weight attribute" + f"but {module} does not have one" ) + + assert ( + check_cpu_version(module.weight.device, "2.6.0") + or is_sm_at_least_89() + or is_MI300() + ), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+ or on CPU with PyTorch >= 2.6.0" + ) + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( module.weight, config ) @@ -1959,6 +2013,9 @@ def _float8_static_activation_float8_weight_transform( weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) + assert activation_granularity == weight_granularity, ( + "Different granularities for activation and weight are not supported" + ) assert isinstance(activation_granularity, PerTensor), ( "Static quantization only supports PerTensor granularity" ) diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 229c94c73a..2c033c6425 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,3 +1,7 @@ +from .float8.float8_opaque_tensor import ( + Float8OpaqueTensor, +) +from .float8.float8_packing_format import Float8PackingFormat from .float8.float8_tensor import ( Float8Tensor, QuantizeTensorToFloat8Kwargs, @@ -36,7 +40,9 @@ "Int4MarlinSparseTensor", "Int4PlainInt32Tensor", "Int4TilePackedTo4dTensor", + "Float8OpaqueTensor", "Float8Tensor", + "Float8PackingFormat", "QuantizeTensorToFloat8Kwargs", "Int4OpaqueTensor", "Int4ChooseQParamsAlgorithm", diff --git a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py new file mode 100644 index 0000000000..39d949f02a --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py @@ -0,0 +1,227 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Optional + +import torch + +from torchao.quantization.granularity import ( + PerGroup, +) +from torchao.quantization.observer import get_block_size +from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, +) +from torchao.utils import ( + TorchAOBaseTensor, +) + +from .float8_tensor import QuantizeTensorToFloat8Kwargs + +__all__ = [ + "Float8OpaqueTensor", +] + +aten = torch.ops.aten + + +class Float8OpaqueTensor(TorchAOBaseTensor): + """ + Float8 dynamic activation float8 weight on CPU. The weight tensor is reordered to a blocked layout + for better memory locality from [N, K] to [N/block_n, K/block_k, block_k, block_n], where block_n = 32 + and block_k depends on group-size for quantization (=32/64/128). And the innermost block with shape + [block_k, block_n] may be further reordered to VNNI layout depending on supported CPU ISA. + + Tensor Attributes: + qdata: Reordered float8 weight on CPU with shape = [N/block_n, K/block_k, block_k, block_n]. + scale: Scale tensor for weight, dtype = float32. For per-group/row quantization, shape = + [N / block_n, num_groups, block_n]. For per-tensor quantization, shape = [1]. + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity. for groupwise quantization, + block_size is (1, group_size). we only support group_size = 32/64/128. For per-row + quantization, blocks_size is (1, K). For per-tensor quantization, block_size is (N, K). + shape: shape of the original Tensor + act_quant_kwargs: the kwargs for from_hp + """ + + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = ["block_size", "act_quant_kwargs"] + + def __new__( + cls, + qdata, + scale, + block_size, + act_quant_kwargs, + ): + if qdata.ndim == 2: + shape = qdata.shape + else: + assert qdata.ndim == 4 + shape = torch.Size( + [qdata.size(0) * qdata.size(3), qdata.size(1) * qdata.size(2)] + ) + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, + ): + self.qdata = qdata + self.scale = scale + self.block_size = block_size + self.act_quant_kwargs = act_quant_kwargs + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, {self.act_quant_kwargs=}" + + @classmethod + def from_hp( + cls, + hp_tensor: torch.Tensor, + block_size: List[int], + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, + ): + assert hp_tensor.ndim == 2 and hp_tensor.device.type == "cpu", ( + f"Expecting 2D tensor on CPU, but got: {hp_tensor.shape} on {hp_tensor.device.type}" + ) + assert len(block_size) == hp_tensor.ndim + N = hp_tensor.size(0) + K = hp_tensor.size(-1) + assert (block_size[0] == 1 or block_size[0] == N) and block_size[1] in ( + 32, + 64, + 128, + K, + ), f"Unsupported block_size: {block_size} for tensor shape {hp_tensor}" + assert act_quant_kwargs is not None, ( + "Activation quantization args must be provided for Float8OpaqueTensor" + ) + act_per_group_quant = isinstance(act_quant_kwargs.granularity, PerGroup) + wei_per_group_quant = block_size[1] < K + if act_per_group_quant: + group_size = act_quant_kwargs.granularity.group_size + if wei_per_group_quant: + # weight_tensor is also per group quantized + assert block_size[1] == group_size, ( + "input and weight should have the same group size but got" + f" {block_size[1]} and {group_size}" + ) + if act_per_group_quant or wei_per_group_quant: + assert N % 32 == 0, ( + f"Expecting out_features {N} to be multiple of 32, but got {N}" + ) + assert K % block_size[1] == 0, ( + f"Expecting in_features {K} to be multiple of group_size {block_size[1]}, but got {K}" + ) + scale = _choose_scale_float8( + hp_tensor, + float8_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn) + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n]. + packed_weight, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu( + data, scale + ) + + return Float8OpaqueTensor( + qdata=packed_weight, + scale=packed_scale, + block_size=block_size, + act_quant_kwargs=act_quant_kwargs, + ) + + +implements = Float8OpaqueTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert input_tensor.device.type == "cpu", ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert isinstance(weight_tensor, Float8OpaqueTensor), ( + f"Expected weight_tensor to be Float8OpaqueTensor, got: {type(weight_tensor)}" + ) + assert weight_tensor.ndim in [2, 4] + assert input_tensor.size(-1) == weight_tensor.size(-1), ( + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" + ) + + act_mat = input_tensor.contiguous() + packed_weight = weight_tensor.qdata + scale = weight_tensor.scale + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # activation float8 quantization + if ( + weight_tensor.act_quant_kwargs is not None + and weight_tensor.act_quant_kwargs.granularity is not None + ): + granularity = weight_tensor.act_quant_kwargs.granularity + if isinstance(granularity, PerGroup): + group_size = granularity.group_size + if weight_tensor.block_size[1] < weight_tensor.size(-1): + # weight_tensor is also per group quantized + assert weight_tensor.block_size[1] == group_size, ( + "input and weight should have the same group size but got" + f" {weight_tensor.block_size[1]} and {group_size}" + ) + act_block_size = get_block_size(act_mat.shape, granularity) + act_scale = _choose_scale_float8( + act_mat, + float8_dtype=torch.float8_e4m3fn, + block_size=act_block_size, + ) + act_mat = _quantize_affine_float8(act_mat, act_scale, torch.float8_e4m3fn) + else: + raise NotImplementedError( + "Activation quantization args not provided for Float8OpaqueTensor" + ) + + # float8 quantized linear operation + y = torch.ops.torchao.float8_linear_cpu.default( + act_mat, + act_scale, + packed_weight, + scale, + bias.float() if bias is not None else bias, # requires bias to be float + torch.float, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) + + +Float8OpaqueTensor.__module__ = "torchao.quantization" + +# Allow a model with Float8OpaqueTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Float8OpaqueTensor]) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py new file mode 100644 index 0000000000..30cf863ac8 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class Float8PackingFormat(str, Enum): + """Packing format for quantized data in Float8 Tensor subclasses in torchao, represents how + the values in quantized data are packed and laid out in memory. + """ + + """ + plain means the format that quantized Tensor data lays out elements in Tensor sequentially, + for example, for a Tensor of shape (4, 6): + a_0_0, a_0_1, ..., a_0_5, + ... + a_3_0, a_3_1, ..., a_3_5 + + """ + PLAIN = "plain" + + """ + Opaque packing format that's used for tensors that does not have a predefined packing format + (that may be decided on hardware, tensor shape, library availability etc.) and it's not + needed for the rest of the system to understand the specific format that's adopted. + """ + OPAQUE = "opaque"