From 736e1f1c1618f75377de900bf877f1af4b82861b Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 7 Jul 2025 17:27:25 +0000 Subject: [PATCH 01/15] [CPU] Add layout and implementation for dynamic float8 act float8 weight on CPU --- .../test_dynamic_float8_linear_cpu.py | 110 ++++ torchao/csrc/cpu/float8_linear.cpp | 522 ++++++++++++++++++ torchao/dtypes/__init__.py | 2 + torchao/dtypes/affine_quantized_tensor_ops.py | 8 + torchao/dtypes/floatx/__init__.py | 4 + .../dyn_float8_act_float8_wei_cpu_layout.py | 281 ++++++++++ torchao/ops.py | 70 +++ torchao/quantization/quant_api.py | 63 ++- 8 files changed, 1046 insertions(+), 14 deletions(-) create mode 100644 test/quantization/test_dynamic_float8_linear_cpu.py create mode 100644 torchao/csrc/cpu/float8_linear.cpp create mode 100644 torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py diff --git a/test/quantization/test_dynamic_float8_linear_cpu.py b/test/quantization/test_dynamic_float8_linear_cpu.py new file mode 100644 index 0000000000..2ccfca1dc7 --- /dev/null +++ b/test/quantization/test_dynamic_float8_linear_cpu.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import copy +import unittest + +import torch +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao import quantize_ +from torchao.dtypes import ( + Float8DynamicActFloat8WeightCPULayout, + PlainLayout, +) +from torchao.quantization import PerRow +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, +) +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, +) + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, K=64, N=32, bias=False): + super().__init__() + self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float) + self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float) + + def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): + return ( + torch.randn( + batch_size, self.linear1.in_features, dtype=dtype, device=device + ), + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestDynamicFloat8Linear(TestCase): + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + m2 = copy.deepcopy(m) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + quantize_( + m, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + layout=Float8DynamicActFloat8WeightCPULayout(), + ), + ) + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + quantize_( + 
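+                # reference path: quantize the copied model (m2) with PlainLayout
+                # so its output can be compared against the fused CPU kernel path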
m2, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + layout=PlainLayout(), + ), + ) + torch._dynamo.reset() # may segfault without this + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + atol, rtol = 1e-6, 1e-6 + if dtype == torch.bfloat16: + atol, rtol = 1.6e-2, 3e-3 + elif dtype == torch.half: + atol, rtol = 6e-3, 2e-3 + assert torch.allclose(y, y2, atol=atol, rtol=rtol) + # Test get_plain by dequantize() + dqw1 = m.linear1.weight.original_weight_tensor.dequantize() + dqw2 = m.linear2.weight.original_weight_tensor.dequantize() + dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() + dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() + assert torch.allclose(dqw1, dqw1_ref) + assert torch.allclose(dqw2, dqw2_ref) + + +common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear) + + +if __name__ == "__main__": + run_tests() diff --git a/torchao/csrc/cpu/float8_linear.cpp b/torchao/csrc/cpu/float8_linear.cpp new file mode 100644 index 0000000000..ed1f4b017c --- /dev/null +++ b/torchao/csrc/cpu/float8_linear.cpp @@ -0,0 +1,522 @@ +#include +#include +#include +#include + +namespace torchao { + +namespace { + +#define BLOCK_N 32 + +static bool cpublas_checked = false; +static bool cpublas_can_pack = false; + +bool cpublas_could_pack() { + // the could_pack check requires AMX support implicitly + if (cpublas_checked) { + return cpublas_can_pack; + } + cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16); + cpublas_checked = true; + return cpublas_can_pack; +} + +/* +return: packed_weight, packed_scales +*/ +std::tuple +float8_linear_prepack_impl( + const at::Tensor& weight, + const at::Tensor& scales) { + // weight shape = [N, K] + // scales shape = [N, G] + TORCH_CHECK(weight.dim() == 2, + "Float8 linear CPU: Weight should be a 2D tensor for packing"); + TORCH_CHECK(weight.size(1) % 2 == 0, + "Float8 linear CPU: Weight should have even number of columns for packing"); + + auto new_scales = scales; + if (new_scales.dim() == 1) { + new_scales.unsqueeze_(1); + } + new_scales = new_scales.to(at::kFloat); + int N = weight.size(0); + int K = weight.size(1); + int G = scales.size(1); + int group_size = K / G; + int block_k = group_size > 128 ? 
128 : group_size; + while (K % block_k != 0) { + block_k /= 2; + } + TORCH_CHECK(block_k > 0 && block_k <= group_size, + "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]"); + constexpr int block_n = BLOCK_N; + int Nc = N / block_n; + int Kc = K / block_k; + + // Reorder weight to [N/block_n, K/block_k, block_k, block_n] + // Reorder scales to [N/block_n, G, block_n] + auto weight_view = weight.view({Nc, block_n, Kc, block_k}); + at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); + at::Tensor blocked_weight; + at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + +#if defined(CPU_CAPABILITY_AVX512) + if (cpublas_could_pack()) { + constexpr int vnni_size = 2; // for float16 + blocked_weight = at::empty({Nc, Kc, block_k, block_n}, weight.options()); + auto weight_ptr = reinterpret_cast(weight_reordered.data_ptr()); + auto blocked_weight_ptr = reinterpret_cast(blocked_weight.data_ptr()); + int64_t num_blocks = Nc * Kc; + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + auto in_ptr = weight_ptr + i * block_k * block_n; + auto out_ptr = blocked_weight_ptr + i * block_k * block_n; + + // Reorder weight block to VNNI + // plain shape = [block_k, block_n] + // packed shape = [block_k / VNNI_SIZE, block_n, VNNI_SIZE] viewed as [block_k, block_n] + constexpr int n_group_size = 8; + constexpr int n_group = block_n / n_group_size; // 4 + for (int nb = 0; nb < n_group; ++nb) { + for (int k = 0; k < block_k; k += vnni_size) { + for (int ni = 0; ni < n_group_size; ++ni) { + for (int ki = 0; ki < vnni_size; ++ki) { + int src_idx = nb * n_group_size + ni + (k + ki) * block_n; + int dst_idx = (nb * n_group_size + ni) * vnni_size + k * block_n + ki; + *(out_ptr + dst_idx) = *(in_ptr + src_idx); + } + } + } + } + } + }); + } else +#endif + { + blocked_weight = weight_reordered; + } + + return std::make_tuple(std::move(blocked_weight), std::move(blocked_scales)); +} + +#if defined(CPU_CAPABILITY_AVX512) +alignas(64) static uint16_t e4m3_to_16bit[256]; + +template +static void initialize_e4m3_to_16bit_tables() { + // run only once + static bool initialized_16bit = false; + if (!initialized_16bit) { + for (uint8_t u8 = 0; u8 < 256; ++u8) { + auto value = static_cast(c10::bit_cast(u8)); + uint16_t value_bits = c10::bit_cast(value); + e4m3_to_16bit[u8] = value_bits; + if (u8 == 255) { + break; + } + } + initialized_16bit = true; + } +} + +template +static void cvt_e4m3_16bit_intrinsic_lut( + const at::Float8_e4m3fn* __restrict__ in, + T* out, + int64_t len) { + for (size_t i = 0; i < len; i += 64) { + __m512i fp8_vec = _mm512_loadu_si512((__m512i*)&in[i]); + __m128i group0 = _mm512_castsi512_si128(fp8_vec); + __m128i group1 = _mm512_extracti32x4_epi32(fp8_vec, 1); + __m128i group2 = _mm512_extracti32x4_epi32(fp8_vec, 2); + __m128i group3 = _mm512_extracti32x4_epi32(fp8_vec, 3); + + __m512i indices0 = _mm512_cvtepu8_epi32(group0); + __m512i indices1 = _mm512_cvtepu8_epi32(group1); + __m512i indices2 = _mm512_cvtepu8_epi32(group2); + __m512i indices3 = _mm512_cvtepu8_epi32(group3); + + // Gather BF16 conversion results from the lookup table. 
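+    // (e4m3 has only 256 bit patterns, so the 256-entry uint16_t table above
+    // covers every possible input; the gather scale is 2 because each table
+    // entry is 2 bytes.)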
+ __m512i bf16_i32_vec0 = _mm512_i32gather_epi32(indices0, e4m3_to_16bit, 2); + __m512i bf16_i32_vec1 = _mm512_i32gather_epi32(indices1, e4m3_to_16bit, 2); + __m512i bf16_i32_vec2 = _mm512_i32gather_epi32(indices2, e4m3_to_16bit, 2); + __m512i bf16_i32_vec3 = _mm512_i32gather_epi32(indices3, e4m3_to_16bit, 2); + + // Helper lambda: Convert 16 32-bit ints (in a __m512i) to 16 16-bit ints. + auto convert_32_to_16 = [](__m512i vec) -> __m256i { + return _mm512_cvtepi32_epi16(vec); + }; + + __m256i bf16_i16_vec0 = convert_32_to_16(bf16_i32_vec0); + __m256i bf16_i16_vec1 = convert_32_to_16(bf16_i32_vec1); + __m256i bf16_i16_vec2 = convert_32_to_16(bf16_i32_vec2); + __m256i bf16_i16_vec3 = convert_32_to_16(bf16_i32_vec3); + + _mm256_storeu_si256((__m256i*)(out + i + 0), bf16_i16_vec0); + _mm256_storeu_si256((__m256i*)(out + i + 16), bf16_i16_vec1); + _mm256_storeu_si256((__m256i*)(out + i + 32), bf16_i16_vec2); + _mm256_storeu_si256((__m256i*)(out + i + 48), bf16_i16_vec3); + } +} + +static void _convert_B_to_bf16( + const at::Float8_e4m3fn* __restrict__ B, + at::BFloat16* dqB, + int64_t len) { + initialize_e4m3_to_16bit_tables(); + int tail = len % 64; + cvt_e4m3_16bit_intrinsic_lut(B, dqB, len - tail); + for (int i = len - tail; i < len; ++i) { + dqB[i] = (at::BFloat16)B[i]; + } +} + +static void _convert_A_to_bf16( + const at::Float8_e4m3fn* __restrict__ A, + at::BFloat16* dqA, + int64_t M, + int64_t K, + int64_t lda) { + initialize_e4m3_to_16bit_tables(); + for (int m = 0; m < M; ++m) { + int tail = K % 64; + int body = K - tail; + cvt_e4m3_16bit_intrinsic_lut(A + m * lda, dqA + m * K, body); + for (int k = body; k < K; ++k) { + dqA[m * K + k] = (at::BFloat16)A[m * lda + k]; + } + } +} + +template +static void _dequant_and_store( + float* __restrict__ output, + const float* __restrict__ input, + const float* __restrict__ scale_a, + const float* __restrict__ scale_b, + int M, + int ldi, + int ldo, + int ldsa = 1) { + for (int m = 0; m < M; ++m) { + float a_scale = *(scale_a + m * ldsa); + __m512 va_scale = _mm512_set1_ps(a_scale); + int n = 0; +#pragma GCC unroll 2 + for (; n < N; n += 16) { + __m512 vc_f = _mm512_loadu_ps(input + m * ldi + n); + __m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale); + __m512 vb_s = _mm512_loadu_ps(scale_b + n); + vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s); + if constexpr (accum) { + __m512 vo = _mm512_loadu_ps(output + m * ldo + n); + _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul)); + } else { + _mm512_storeu_ps(output + m * ldo + n, vc_f_mul); + } + } + for (; n < N; ++n) { + float dq_val = input[m * ldi + n] * a_scale * scale_b[n]; + if constexpr (accum) { + output[m * ldo + n] += dq_val; + } else { + output[m * ldo + n] = dq_val; + } + } + } +} + +#else +static void _convert_B_to_bf16( + const at::Float8_e4m3fn* B, + at::BFloat16* dqB, + int64_t len) { + for (int i = 0; i < len; ++i) { + dqB[i] = (at::BFloat16)B[i]; + } +} + +static void _convert_A_to_bf16( + const at::Float8_e4m3fn* __restrict__ A, + at::BFloat16* dqA, + int64_t M, + int64_t K, + int64_t lda) { + for (int m = 0; m < M; ++m) { + for (int k = 0; k < K; ++k) { + dqA[m * K + k] = (at::BFloat16)A[m * lda + k]; + } + } +} +#endif + +template +void _dequant_gemm_accum( + float* C, + const at::Float8_e4m3fn* A, + const float* scales_a, + const at::Float8_e4m3fn* B, + const float* scales_b, + int64_t M, + int64_t K, + int64_t lda, + int64_t ldc) { + // Compute GEMM fp8 * fp8 -> fp32 + // Then apply scales and store results + at::BFloat16 dqB[K * N]; + _convert_B_to_bf16(B, dqB, K * 
N); + at::BFloat16 dqA[M * K]; + _convert_A_to_bf16(A, dqA, M, K, lda); +#if defined(CPU_CAPABILITY_AVX512) + if constexpr (cpublas_can_pack) { + float C_f32[M * N]; + at::native::cpublas::brgemm( + M, + N, + K, + K /*lda*/, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + dqA, + dqB, + C_f32, + true /* is_vnni */); + _mm_prefetch(B + N * K, _MM_HINT_T0); + _mm_prefetch(A + K, _MM_HINT_T0); + _dequant_and_store( + C, + C_f32, + scales_a, + scales_b, + M, + N /*ldi*/, + ldc, + 1 /*ldsa*/); + } else +#endif + { + for (int64_t i = 0; i < M; ++i) { + for (int64_t j = 0; j < N; ++j) { + float sum = 0; + for (int64_t k = 0; k < K; ++k) { + sum += ((float)dqA[i * K + k] * dqB[k * N + j]); + } + C[i * ldc + j] += sum * scales_a[i] * scales_b[j]; + } + } + } +} + +template +inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) { + if (bias_ptr) { + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j); + _mm512_storeu_ps(y_buf + i * N + j, bias_vec); + } +#endif + for (; j < N; ++j) { + y_buf[i * N + j] = bias_ptr[j]; + } + } + } else { // initialize to zero + for (int i = 0; i < m; ++i) { + int j = 0; +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 zero_vec = _mm512_setzero_ps(); + _mm512_storeu_ps(y_buf + i * N + j, zero_vec); + } +#endif + for (; j < N; ++j) { + y_buf[i * N + j] = 0; + } + } + } +} + +template +inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, /* int64_t n, */ int64_t lda) { + for (int i = 0; i < m; ++i) { + int j = 0; + if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = y_buf[i * N + j]; + } + } else if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]); + } + } else if constexpr (std::is_same::value) { +#if defined(CPU_CAPABILITY_AVX512) +#pragma GCC unroll 2 + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } +#endif + for (; j < N; ++j) { + c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]); + } + } else { + TORCH_CHECK(false, "Unsupported output dtype"); + } + } +} + +template +void _float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::Tensor& output) { + // input shape = [..., K] + // input is per token quantized + int64_t K = input.size(-1); + auto input_view = input.view({-1, K}); + int64_t M = input_view.size(0); + TORCH_CHECK(input_scales.numel() == M, "Float8 linear: unexpected input scales shape"); + + // weight shape = [Nc, Kc, block_k, block_n] + // scales shape = [Nc, G, block_n] + int64_t Nc = weight.size(0); + int64_t Kc = weight.size(1); + int64_t block_k = weight.size(2); + constexpr int64_t block_n = 
BLOCK_N; + TORCH_CHECK(weight.size(3) == block_n, "Float8 linear: unexpected weight shape"); + int64_t N = Nc * block_n; + TORCH_CHECK(K == Kc * block_k, "Float8 linear: weight and input shapes mismatch"); + int64_t block_m = [&]() -> long { + if (M <= 48) { + return M; + } else if (M < 64) { + return 32; + } else if (M < 96) { + return 64; + } else { + return 128; + } + }(); + int64_t Mc = (M + block_m - 1) / block_m; + bool parallel_on_M = M > 128; + int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc; + + // scales shape = [Nc, G, block_n] + int64_t num_groups = weight_scales.size(1); + int64_t group_size = K / num_groups; + TORCH_CHECK(group_size % block_k == 0, + "Float8 linear: group_size should be divisible by block_k"); + int64_t block_per_group = group_size / block_k; + + const at::Float8_e4m3fn* a_ptr = input_view.data_ptr(); + const float* a_scales_ptr = input_scales.data_ptr(); + const at::Float8_e4m3fn* b_ptr = weight.data_ptr(); + const float* b_scales_ptr = weight_scales.data_ptr(); + out_dtype* c_ptr = output.data_ptr(); + const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; + + at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + for (const auto i : c10::irange(begin, end)) { + int64_t mc = parallel_on_M ? i / Nc : 0; + int64_t nc = parallel_on_M ? i % Nc : i; + int64_t mc_end = parallel_on_M ? mc + 1 : Mc; + + for (int mci = mc; mci < mc_end; ++mci) { + int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; + alignas(64) float y_buf[m_size][block_n]; + // copy bias to y_buf if bias is not None + auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; + copy_bias(bias_data, y_buf[0], m_size); + for (int kci = 0; kci < Kc; ++kci) { + _dequant_gemm_accum( + y_buf[0] /*C*/, + a_ptr + mci * block_m * K + kci * block_k /*A*/, + a_scales_ptr + mci * block_m /*scales_a*/, + b_ptr + (nc * Kc + kci) * block_n * block_k /*B*/, + b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*scales_b*/, + m_size /*M*/, + block_k /*K*/, + K /*lda*/, + block_n /*ldc*/); + } + // store y_buf to output with dtype conversion + store_out( + y_buf[0], + c_ptr + mci * block_m * N + nc * block_n, + m_size, + N /*lda*/); + } + } + if constexpr (cpublas_can_pack) { + at::native::cpublas::brgemm_release(); + } + }); +} + +at::Tensor float8_linear_impl( + const at::Tensor& input, + const at::Tensor& input_scales, + const at::Tensor& weight, + const at::Tensor& weight_scales, + const std::optional& bias, + at::ScalarType output_dtype) { + static bool cpublas_can_pack = cpublas_could_pack(); + auto out_sizes = input.sizes().vec(); + int64_t N = weight.size(0) * weight.size(-1); + out_sizes.back() = N; + auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); + +#define call__float8_linear_impl(cpublas_can_pack) \ + AT_DISPATCH_FLOATING_TYPES_AND2( \ + at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "float8_linear_cpu", [&] { \ + _float8_linear_impl( \ + input, \ + input_scales, \ + weight, \ + weight_scales, \ + bias, \ + output); \ + }); + + if (cpublas_can_pack) { + call__float8_linear_impl(true); + } else { + call__float8_linear_impl(false); + } + return output; +} + +} // anonymous namespace + +TORCH_LIBRARY_IMPL(torchao, CPU, m) { + m.impl("torchao::float8_linear_prepack_cpu", &float8_linear_prepack_impl); + m.impl("torchao::float8_linear_cpu", &float8_linear_impl); +} + +} // namespace torchao diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 
d6b1b9c440..476df2aace 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -12,6 +12,7 @@ from .fbgemm_int4_tensor import FbgemmInt4Tensor, to_fbgemm_int4 from .floatx import ( CutlassSemiSparseLayout, + Float8DynamicActFloat8WeightCPULayout, Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 @@ -70,4 +71,5 @@ "FbgemmFp8Tensor", "Int8DynamicActInt4WeightCPULayout", "Int4GroupwisePreshuffleTensor", + "Float8DynamicActFloat8WeightCPULayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index 8b028352e4..a28e764cb8 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -15,6 +15,10 @@ _linear_fp8_act_fp8_weight_sparse_cutlass_check, _linear_fp8_act_fp8_weight_sparse_cutlass_impl, ) +from torchao.dtypes.floatx.dyn_float8_act_float8_wei_cpu_layout import ( + _float8_linear_cpu_check, + _float8_linear_cpu_impl, +) from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl, @@ -255,6 +259,10 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, ), + ( + _float8_linear_cpu_check, + _float8_linear_cpu_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 7e634a5211..05744e6b50 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,6 +1,9 @@ from .cutlass_semi_sparse_layout import ( CutlassSemiSparseLayout, ) +from .dyn_float8_act_float8_wei_cpu_layout import ( + Float8DynamicActFloat8WeightCPULayout, +) from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( FloatxTensorCoreLayout, @@ -14,4 +17,5 @@ "from_scaled_tc_floatx", "Float8Layout", "CutlassSemiSparseLayout", + "Float8DynamicActFloat8WeightCPULayout", ] diff --git a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py new file mode 100644 index 0000000000..2c37193efb --- /dev/null +++ b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. 
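+#
+# This file adds Float8DynamicActFloat8WeightCPULayout and its AQT tensor impl:
+# weights are packed via torch.ops.torchao.float8_linear_prepack_cpu and the
+# quantized linear dispatch below calls torch.ops.torchao.float8_linear_cpu.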
+from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.utils._python_dispatch import ( + return_and_correct_aliasing, +) + +from torchao.dtypes.affine_quantized_tensor import ( + AffineQuantizedTensor, + register_layout, +) +from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout, is_device +from torchao.utils import ( + TORCH_VERSION_AT_LEAST_2_6, + fill_defaults, +) + +from ..uintx.int4_cpu_layout import ( + _is_float, +) + +aten = torch.ops.aten + + +@dataclass(frozen=True) +class Float8DynamicActFloat8WeightCPULayout(Layout): + """Layout class for float8 da8w8 CPU layout for affine quantized tensor""" + + pass + + +@register_layout(Float8DynamicActFloat8WeightCPULayout) +class Float8DynActFloat8WeiCpuAQTTensorImpl(AQTTensorImpl): + """TensorImpl for float8 da8w8 CPU layout for affine quantized tensor""" + + def __new__( + cls, + packed_weight: torch.Tensor, + scales: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + kwargs = {} + kwargs["device"] = packed_weight.device + kwargs["layout"] = ( + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout + ) + kwargs["dtype"] = packed_weight.dtype + kwargs["requires_grad"] = False + shape = packed_weight.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + packed_weight: torch.Tensor, + scales: torch.Tensor, + transposed: bool, + _layout: Layout, + ): + self.packed_weight = packed_weight + self.scales = scales + self.transposed = transposed + self._layout = _layout + + def __tensor_flatten__(self): + return ["packed_weight", "scales"], [ + self.transposed, + self._layout, + ] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + packed_weight, scales = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scales"], + ) + (_layout, transposed) = tensor_attributes + return cls(packed_weight, scales, transposed, _layout) + + @classmethod + def from_plain( + cls, + data: torch.Tensor, + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + _layout: Layout, + ): + assert isinstance(_layout, Float8DynamicActFloat8WeightCPULayout) + assert data.dtype == torch.float8_e4m3fn, ( + "Float8 DA8W8 CPU: expects float8_e4m3fn weight" + ) + if scale.dim() == 1: + scale.unsqueeze_(-1) + scale = scale.to(torch.float) + + K = data.size(-1) + if K % 32 == 0: + weight_packed, scales = torch.ops.torchao.float8_linear_prepack_cpu( + data, scale + ) + else: + weight_packed = data + scales = scale + _layout = PlainLayout + return cls(weight_packed, scales, False, _layout) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.packed_weight), + fn(self.scales), + self.transposed, + self._layout, + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + kwargs = {} if kwargs is None else kwargs + + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + if func is aten.slice.Tensor: + self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) + if dim in [0, 1]: + assert step == 1, "Only step == 1 is supported in slicing right now" + data, scale = self.get_plain() + data_len = data.shape[dim] + scale_len = scale.shape[dim] + ratio = data_len / scale_len 
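+                # data and scales cover the sliced dim at different granularities,
+                # so the slice bounds for scales are derived proportionally
+                # (assumes data_len is an integer multiple of scale_len)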
+ start_scale = int(start / ratio) + end_scale = int(end / ratio) + + data = aten.slice.Tensor(data, dim, start, end, step) + scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) + # this is to handle padding + data, scale = self._layout.post_process(data, scale, self.block_size) + sliced = self.from_plain(data, scale, self._layout) + return return_and_correct_aliasing(func, args, kwargs, sliced) + else: + raise NotImplementedError( + f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) + + raise NotImplementedError( + f"{cls.__name__} dispatch: attempting to run {func}, this is not supported" + ) + + __torch_function__ = torch._C._disabled_torch_function_impl + + @property + def block_size(self): + assert len(self.packed_weight.shape) == 2 + weight_shape = self.packed_weight.shape + N = weight_shape[0] + K = weight_shape[1] + groups = self.scales.numel() // N + group_size = K // groups + return (1, group_size) + + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if self._layout == PlainLayout: + # If the layout is PlainLayout, return the packed weight and scales directly + return ( + self.packed_weight, + self.scales, + torch.zeros_like(self.scales), + ) + # Unpack weight by linear(eye(K), packed_weight).t() + packed_w_shape = self.packed_weight.shape + if len(packed_w_shape) == 4: + K = packed_w_shape[1] * packed_w_shape[2] + else: + K = packed_w_shape[1] + x = torch.eye(K).to(torch.float8_e4m3fn) + x_scale = torch.ones(K).float() + w_scale = torch.ones_like(self.scales).float() + plain_weight = torch.ops.torchao.float8_linear_cpu.default( + x, + x_scale, + self.packed_weight, + w_scale, + None, # bias + torch.float, # out_dtype + ) + plain_weight = plain_weight.t().contiguous() + plain_weight = plain_weight.to(torch.float8_e4m3fn) + + if self.scales.dim() == 2: + plain_scales = self.scales + else: + assert self.scales.dim() == 3 + packed_shape = self.scales.shape # [Nc, G, block_n] + plain_scales = ( + self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) + ) + + return plain_weight, plain_scales, torch.zeros_like(plain_scales) + + +def _aqt_is_float8e4m3(aqt): + """Check if an AffineQuantizedTensor is float8_e4m3fn quantized Tensor""" + return aqt.tensor_impl.dtype == torch.float8_e4m3fn + + +def _float8_linear_cpu_check(input_tensor, weight_tensor, bias): + return ( + TORCH_VERSION_AT_LEAST_2_6 + and is_device(input_tensor.device.type, "cpu") + and is_device(weight_tensor.device.type, "cpu") + and (bias is None or is_device(bias.device.type, "cpu")) + and isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_float8e4m3(input_tensor) + and _is_float(input_tensor.dtype) + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_float8e4m3(weight_tensor) + and _is_float(weight_tensor.dtype) + and isinstance(weight_tensor._layout, Float8DynamicActFloat8WeightCPULayout) + ) + + +def _float8_linear_cpu_impl(input_tensor, weight_tensor, bias): + assert TORCH_VERSION_AT_LEAST_2_6, ( + f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" + ) + assert is_device(input_tensor.device.type, "cpu"), ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert weight_tensor.block_size[0] == 1, ( + f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"need input_tensor shape: {input_tensor.shape} final" + f"dim to 
match weight_tensor shape: {weight_tensor.shape} second dim " + ) + + act_mat = input_tensor + act = act_mat.tensor_impl.int_data + act_scales = act_mat.tensor_impl.scale + + packed_weight = weight_tensor.tensor_impl.packed_weight + wei_scales = weight_tensor.tensor_impl.scales + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + + # reshape to 2D + act = act.reshape(-1, act.shape[-1]) + + y = torch.ops.torchao.float8_linear_cpu.default( + act.contiguous(), + act_scales, + packed_weight, + wei_scales, + bias.float() if bias is not None else bias, # requires bias to be float + torch.float, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) diff --git a/torchao/ops.py b/torchao/ops.py index babe5506c0..178e98f589 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -70,6 +70,12 @@ lib.define( "da8w4_linear_cpu(Tensor input, Tensor input_scales, Tensor input_qzeros, Tensor weight, Tensor weight_scales, Tensor weight_qzeros, Tensor compensation, Tensor? bias, ScalarType output_dtype) -> Tensor" ) +lib.define( + "float8_linear_prepack_cpu(Tensor weight, Tensor scales) -> (Tensor, Tensor)" +) +lib.define( + "float8_linear_cpu(Tensor input, Tensor input_scales, Tensor weight, Tensor weight_scales, Tensor? bias, ScalarType output_dtype) -> Tensor" +) def register_custom_op(name): @@ -1106,3 +1112,67 @@ def _( assert weight.dim() == 4 N = weight.size(0) * weight.size(3) * 2 return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) + + +def float8_linear_prepack_cpu( + weight: Tensor, + scales: Tensor, +) -> Tensor: + """ + Prepack weights for float8 linear operator on CPU. + Args: + weight: weight tensor. + scales: scales for weight tensor. + Returns: + packed weight, packed scales + """ + return torch.ops.torchao.float8_linear_prepack_cpu.default(weight, scales) + + +@register_custom_op("torchao::float8_linear_prepack_cpu") +def _(weight: Tensor, scales: Tensor) -> Tensor: + return weight, scales + + +def float8_linear_cpu( + input: Tensor, + input_scales: Tensor, + weight: Tensor, + weight_scales: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +): + """ + float8 linear operator on CPU. + Args: + input: input tensor. + input_scales: scales for input tensor. + weight: weight tensor. + weight_scales: scales for weight tensor. + bias: optional bias tensor. + out_dtype: output data type. + Returns: + output tensor in out_dtype. 
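+
+    Example (illustrative sketch, not part of this patch; assumes `packed_w` and
+    `packed_s` come from torchao.ops.float8_linear_prepack_cpu, and `x_fp8`,
+    `x_scales` are per-row float8_e4m3fn activations with float32 scales):
+        y = torch.ops.torchao.float8_linear_cpu(
+            x_fp8, x_scales, packed_w, packed_s, None, torch.bfloat16
+        )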
+ """ + return torch.ops.torchao.float8_linear_cpu.default( + input, + input_scales, + weight, + weight_scales, + bias, + out_dtype, + ) + + +@register_custom_op("torchao::float8_linear_cpu") +def _( + input: Tensor, + input_scales: Tensor, + weight: Tensor, + weight_scales: Tensor, + bias: Optional[Tensor], + out_dtype: torch.dtype, +) -> Tensor: + assert weight.dim() == 4 + N = weight.size(0) * weight.size(3) + return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 7df6995955..7c282fd891 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -82,6 +82,7 @@ TORCH_VERSION_AT_LEAST_2_5, TORCH_VERSION_AT_LEAST_2_6, _is_fbgemm_genai_gpu_available, + check_cpu_version, is_MI300, is_sm_at_least_89, is_sm_at_least_90, @@ -1545,6 +1546,22 @@ def _input_activation_quant_func_fp8( return activation +def _input_activation_quant_cpu_fp8( + x: torch.Tensor, + activation_granularity: FP8Granularity, + activation_dtype: torch.dtype, +): + """Dynamic quantize activation to fp8 for CPU.""" + block_size = get_block_size(x.shape, activation_granularity) + return to_affine_quantized_floatx( + input_float=x, + block_size=block_size, + target_dtype=activation_dtype, + scale_dtype=torch.float32, + _layout=PlainLayout(), + ) + + def _fp8_mm_compat(weight: torch.Tensor) -> bool: """ Check if a weight tensor meets float8 quantization requirements. @@ -1595,10 +1612,13 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None mm_config: Optional[Float8MMConfig] = None set_inductor_config: bool = True + layout: Optional[Layout] = None def __post_init__(self): - if self.mm_config is None: - self.mm_config = Float8MMConfig(use_fast_accum=True) + if self.layout is None: + if self.mm_config is None: + self.mm_config = Float8MMConfig(use_fast_accum=True) + self.layout = Float8Layout(self.mm_config) activation_granularity, weight_granularity = _normalize_granularity( self.granularity @@ -1614,17 +1634,23 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype granularity = config.granularity - mm_config = config.mm_config # Ensure works on device - _check_hardware_support(granularity) activation_granularity, weight_granularity = granularity + is_cpu = weight.device.type == "cpu" + if is_cpu: + assert not ( + isinstance(activation_granularity, PerTensor) + or isinstance(weight_granularity, PerTensor) + ), "PerTensor quantization is not supported for CPU float8 quantization" + else: + _check_hardware_support(granularity) - if not _fp8_mm_compat(weight): + if not is_cpu and not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently # not doing what the user asked return weight - if isinstance(weight_granularity, PerRow): + if not is_cpu and isinstance(weight_granularity, PerRow): assert weight.dtype == torch.bfloat16, ( "PerRow quantization only works for bfloat16 precision input weight" ) @@ -1637,10 +1663,14 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): block_size=block_size, target_dtype=weight_dtype, scale_dtype=torch.float32, - _layout=Float8Layout(mm_config=mm_config), + _layout=config.layout, ) - input_quant_func = _input_activation_quant_func_fp8 + input_quant_func = ( + _input_activation_quant_func_fp8 + if 
isinstance(config.layout, Float8Layout) + else _input_activation_quant_cpu_fp8 + ) input_quant_kwargs = { "activation_granularity": activation_granularity, "activation_dtype": activation_dtype, @@ -1656,16 +1686,21 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): def _float8_dynamic_activation_float8_weight_transform( module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig ): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - ) - if config.set_inductor_config: - torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, "weight"), ( "applying float8 dynamic activation quant requires module to have weight attribute" + f"but {module} does not have one" ) + + assert ( + check_cpu_version(module.weight.device, "2.6.0") + or is_sm_at_least_89() + or is_MI300() + ), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+ or on CPU with PyTorch >= 2.6.0" + ) + if config.set_inductor_config: + torchao.quantization.utils.recommended_inductor_config_setter() + quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( module.weight, config ) From c23838567f1a5362b5400518969809ce382a0276 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 10 Jul 2025 14:47:50 +0000 Subject: [PATCH 02/15] Refine code --- torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py index 2c37193efb..a5f30eae71 100644 --- a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py +++ b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py @@ -105,13 +105,15 @@ def from_plain( K = data.size(-1) if K % 32 == 0: + # weight is packed to [N / block_n, K / block_k, block_k, block_n] + # The inner block [block_k, block_n] are packed to VNNI layout if AMX is available. weight_packed, scales = torch.ops.torchao.float8_linear_prepack_cpu( data, scale ) else: weight_packed = data scales = scale - _layout = PlainLayout + _layout = PlainLayout() return cls(weight_packed, scales, False, _layout) def _apply_fn_to_data(self, fn): From 3e7d17959591c01fcadacf7f45663020b3837834 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Thu, 10 Jul 2025 14:54:23 +0000 Subject: [PATCH 03/15] refine comments --- .../dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py index a5f30eae71..7b581bf6a6 100644 --- a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py +++ b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py @@ -105,8 +105,9 @@ def from_plain( K = data.size(-1) if K % 32 == 0: - # weight is packed to [N / block_n, K / block_k, block_k, block_n] - # The inner block [block_k, block_n] are packed to VNNI layout if AMX is available. + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack inner blocks [block_k, block_n] to VNNI layout if AMX is available. + # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n]. 
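+            # (VNNI interleaves pairs of K elements so the AMX/brgemm path in
+            # float8_linear.cpp can consume the bf16 operands directly.)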
weight_packed, scales = torch.ops.torchao.float8_linear_prepack_cpu( data, scale ) From 953ac130787806c1daad4a488d5a19fddaad56b3 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 11 Jul 2025 17:57:46 +0000 Subject: [PATCH 04/15] Check K % num_groups == 0 --- torchao/csrc/cpu/float8_linear.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/torchao/csrc/cpu/float8_linear.cpp b/torchao/csrc/cpu/float8_linear.cpp index ed1f4b017c..c211168069 100644 --- a/torchao/csrc/cpu/float8_linear.cpp +++ b/torchao/csrc/cpu/float8_linear.cpp @@ -429,6 +429,7 @@ void _float8_linear_impl( // scales shape = [Nc, G, block_n] int64_t num_groups = weight_scales.size(1); + TORCH_CHECK(K % num_groups == 0, "K should be divisible by num_groups"); int64_t group_size = K / num_groups; TORCH_CHECK(group_size % block_k == 0, "Float8 linear: group_size should be divisible by block_k"); From cd5380204f76952792c09cc4c0fc626a86c9ed0c Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Mon, 14 Jul 2025 15:30:06 +0000 Subject: [PATCH 05/15] Check N & K % 32 == 0; update UT --- .../test_dynamic_float8_linear_cpu.py | 57 +++++++++++++++++-- torchao/csrc/cpu/float8_linear.cpp | 2 + .../dyn_float8_act_float8_wei_cpu_layout.py | 7 ++- 3 files changed, 59 insertions(+), 7 deletions(-) diff --git a/test/quantization/test_dynamic_float8_linear_cpu.py b/test/quantization/test_dynamic_float8_linear_cpu.py index 2ccfca1dc7..7f81965a0e 100644 --- a/test/quantization/test_dynamic_float8_linear_cpu.py +++ b/test/quantization/test_dynamic_float8_linear_cpu.py @@ -36,9 +36,8 @@ def __init__(self, K=64, N=32, bias=False): def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): return ( - torch.randn( - batch_size, self.linear1.in_features, dtype=dtype, device=device - ), + torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device) + * 0.1, ) def forward(self, x): @@ -88,7 +87,7 @@ def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs): ) torch._dynamo.reset() # may segfault without this y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) - atol, rtol = 1e-6, 1e-6 + atol, rtol = 1e-4, 1e-6 if dtype == torch.bfloat16: atol, rtol = 1.6e-2, 3e-3 elif dtype == torch.half: @@ -102,6 +101,56 @@ def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs): assert torch.allclose(dqw1, dqw1_ref) assert torch.allclose(dqw2, dqw2_ref) + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias): + device = "cpu" + # the shape is not supported by cpp kernel, so the ref path will be used. 
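+        # (120 is not a multiple of 32, so from_plain keeps the plain weight and
+        # the fused float8_linear_cpu op is not used.)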
+ m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device) + m2 = copy.deepcopy(m) + bs = 4 + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + + with torch.no_grad(): + quantize_( + m, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + layout=Float8DynamicActFloat8WeightCPULayout(), + ), + ) + y, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the op is not in the code + assert "torch.ops.torchao.float8_linear_cpu.default" not in code[0] + quantize_( + m2, + Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + layout=PlainLayout(), + ), + ) + torch._dynamo.reset() # may segfault without this + y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) + assert torch.allclose(y, y2) + # Test get_plain by dequantize() + dqw1 = m.linear1.weight.original_weight_tensor.dequantize() + dqw2 = m.linear2.weight.original_weight_tensor.dequantize() + dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() + dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() + assert torch.allclose(dqw1, dqw1_ref) + assert torch.allclose(dqw2, dqw2_ref) + common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear) diff --git a/torchao/csrc/cpu/float8_linear.cpp b/torchao/csrc/cpu/float8_linear.cpp index c211168069..586c3e9439 100644 --- a/torchao/csrc/cpu/float8_linear.cpp +++ b/torchao/csrc/cpu/float8_linear.cpp @@ -44,6 +44,7 @@ float8_linear_prepack_impl( int N = weight.size(0); int K = weight.size(1); int G = scales.size(1); + TORCH_CHECK(K % G == 0, "K should be divisible by num_groups"); int group_size = K / G; int block_k = group_size > 128 ? 128 : group_size; while (K % block_k != 0) { @@ -52,6 +53,7 @@ float8_linear_prepack_impl( TORCH_CHECK(block_k > 0 && block_k <= group_size, "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]"); constexpr int block_n = BLOCK_N; + TORCH_CHECK(N % block_n == 0, "N should be divisible by 32"); int Nc = N / block_n; int Kc = K / block_k; diff --git a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py index 7b581bf6a6..2252f158fa 100644 --- a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py +++ b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py @@ -84,7 +84,7 @@ def __tensor_unflatten__( tensor_data_dict["packed_weight"], tensor_data_dict["scales"], ) - (_layout, transposed) = tensor_attributes + (transposed, _layout) = tensor_attributes return cls(packed_weight, scales, transposed, _layout) @classmethod @@ -103,8 +103,9 @@ def from_plain( scale.unsqueeze_(-1) scale = scale.to(torch.float) + N = data.size(0) K = data.size(-1) - if K % 32 == 0: + if N % 32 == 0 and K % 32 == 0: # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. # Pack inner blocks [block_k, block_n] to VNNI layout if AMX is available. # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n]. 
@@ -178,7 +179,7 @@ def block_size(self): return (1, group_size) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if self._layout == PlainLayout: + if isinstance(self._layout, PlainLayout): # If the layout is PlainLayout, return the packed weight and scales directly return ( self.packed_weight, From b4f652070f738224ecace745b8d1bc3586915308 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 9 Sep 2025 10:11:56 +0000 Subject: [PATCH 06/15] [CPU] Add float8OpaqueTensor for daf8wf8 --- torchao/csrc/cpu/dispatcher.h | 191 ++++++++++++ torchao/dtypes/__init__.py | 2 - torchao/dtypes/affine_quantized_tensor_ops.py | 8 - torchao/dtypes/floatx/__init__.py | 4 - .../dyn_float8_act_float8_wei_cpu_layout.py | 285 ------------------ torchao/quantization/__init__.py | 2 + torchao/quantization/observer.py | 9 +- torchao/quantization/quant_api.py | 27 +- .../quantize_/workflows/__init__.py | 4 + .../workflows/float8/float8_opaque_tensor.py | 222 ++++++++++++++ 10 files changed, 446 insertions(+), 308 deletions(-) create mode 100644 torchao/csrc/cpu/dispatcher.h delete mode 100644 torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py create mode 100644 torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py diff --git a/torchao/csrc/cpu/dispatcher.h b/torchao/csrc/cpu/dispatcher.h new file mode 100644 index 0000000000..81edbfa971 --- /dev/null +++ b/torchao/csrc/cpu/dispatcher.h @@ -0,0 +1,191 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include + +template < + typename IntegralType, + int n, + IntegralType First, + IntegralType... Rest> +struct enumerate_dispatcher_helper { + template + inline static void call( + IntegralType i, + const std::function& comparator, + const Lambda1& function, + const Lambda2& fallback, + Args... args) { + if (comparator(i, First)) + function( + std::integral_constant{}, + std::forward(args)...); + else + enumerate_dispatcher_helper::call( + i, comparator, function, fallback, std::forward(args)...); + } +}; + +template +struct enumerate_dispatcher_helper { + template + inline static void call( + IntegralType i, + const std::function& comparator, + const Lambda1& function, + const Lambda2& fallback, + Args... args) { + if (comparator(i, First)) + function( + std::integral_constant{}, + std::forward(args)...); + else + fallback(i, std::forward(args)...); + } +}; + +// dispatch a list of integers to a lambda function +template +struct enumerate_dispatcher { + template + inline static void call( + IntegralType i, + const Lambda1& function, + const Lambda2& fallback, + Args... 
args) { + enumerate_dispatcher_helper:: + call( + i, + [&](IntegralType a, IntegralType b) { return a == b; }, + function, + fallback, + std::forward(args)...); + } +}; + +// A helper function that returns the last N-1 items of a tuple as a new tuple +template +auto get_last_n_minus_one_impl(TupleType&& t, std::index_sequence) { + return std::tuple_cat(std::make_tuple(std::get(t))...); +} + +// A function that returns the last N-1 items of a tuple as a new tuple +template +auto get_last_n_minus_one(TupleType&& t) { + // Get the size of the tuple + constexpr auto size = + std::tuple_size::type>::value; + // Check if the size is greater than one + return get_last_n_minus_one_impl( + std::forward(t), std::make_index_sequence{}); +} + +template < + typename TupleType, + std::enable_if_t::value == 1, bool> = true> +auto get_last_n_minus_one(TupleType&& t) { + return std::make_tuple(); +} + +template < + typename IntegralTypesProcessed, + typename IntegralTypesToProcess, + typename Dispatchers> +struct product_dispatcher_helper; + +template +struct product_dispatcher_helper< + std::tuple, + std::tuple<>, + std::tuple<>> { + template + inline static void call( + std::tuple<>, + std::tuple constants, + std::tuple<>, + const Lambda1& function, + const Lambda2& fallback, + Args... args) { + function(constants, std::forward(args)...); + } +}; + +template < + typename... IntegralTypeProcessed, + typename... IntegeralTypeToProcess, + typename... Dispatcher> +struct product_dispatcher_helper< + std::tuple, + std::tuple, + std::tuple> { + template + inline static void call( + std::tuple dispatchers, + std::tuple constants, + std::tuple integrals, + const Lambda1& function, + const Lambda2& fallback, + Args... args) { + std::get<0>(dispatchers) + .call( + std::get<0>(integrals), + [&](auto i, Args... args) { + auto new_dispatchers = get_last_n_minus_one(dispatchers); + auto new_constants = + std::tuple_cat(constants, std::tuple(i)); + auto new_integrals = get_last_n_minus_one(integrals); + product_dispatcher_helper< + decltype(new_constants), + decltype(new_integrals), + decltype(new_dispatchers)>:: + call( + new_dispatchers, + new_constants, + new_integrals, + function, + fallback, + std::forward(args)...); + }, + [&](auto i, Args... args) { + fallback( + std::tuple_cat(constants, integrals), + std::forward(args)...); + }, + std::forward(args)...); + } +}; + +template +struct product_dispatcher; + +// dispatch to a carsian product of a list of integers to a lambda function +template +struct product_dispatcher< + std::tuple, + std::tuple> { + template + inline static void call( + std::tuple integrals, + const Lambda1& function, + const Lambda2& fallback, + Args... 
args) { + static auto dispatchers = std::tuple{}; + product_dispatcher_helper< + std::tuple<>, + std::tuple, + std::tuple>:: + call( + dispatchers, + std::tuple<>{}, + integrals, + function, + fallback, + std::forward(args)...); + } +}; diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 8db94fd0b2..575e154091 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -11,7 +11,6 @@ from .fbgemm_fp8_tensor import FbgemmFp8Tensor, to_fbgemm_fp8 from .floatx import ( CutlassSemiSparseLayout, - Float8DynamicActFloat8WeightCPULayout, Float8Layout, ) from .nf4tensor import NF4Tensor, to_nf4 @@ -68,5 +67,4 @@ "FbgemmFp8Tensor", "Int8DynamicActInt4WeightCPULayout", "Int4GroupwisePreshuffleTensor", - "Float8DynamicActFloat8WeightCPULayout", ] diff --git a/torchao/dtypes/affine_quantized_tensor_ops.py b/torchao/dtypes/affine_quantized_tensor_ops.py index a28e764cb8..8b028352e4 100644 --- a/torchao/dtypes/affine_quantized_tensor_ops.py +++ b/torchao/dtypes/affine_quantized_tensor_ops.py @@ -15,10 +15,6 @@ _linear_fp8_act_fp8_weight_sparse_cutlass_check, _linear_fp8_act_fp8_weight_sparse_cutlass_impl, ) -from torchao.dtypes.floatx.dyn_float8_act_float8_wei_cpu_layout import ( - _float8_linear_cpu_check, - _float8_linear_cpu_impl, -) from torchao.dtypes.floatx.float8_layout import ( _linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl, @@ -259,10 +255,6 @@ def _register_aqt_quantized_linear_dispatches(): _linear_int8_act_int4_weight_cpu_check, _linear_int8_act_int4_weight_cpu_impl, ), - ( - _float8_linear_cpu_check, - _float8_linear_cpu_impl, - ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index 05744e6b50..7e634a5211 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1,9 +1,6 @@ from .cutlass_semi_sparse_layout import ( CutlassSemiSparseLayout, ) -from .dyn_float8_act_float8_wei_cpu_layout import ( - Float8DynamicActFloat8WeightCPULayout, -) from .float8_layout import Float8Layout from .floatx_tensor_core_layout import ( FloatxTensorCoreLayout, @@ -17,5 +14,4 @@ "from_scaled_tc_floatx", "Float8Layout", "CutlassSemiSparseLayout", - "Float8DynamicActFloat8WeightCPULayout", ] diff --git a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py b/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py deleted file mode 100644 index 2252f158fa..0000000000 --- a/torchao/dtypes/floatx/dyn_float8_act_float8_wei_cpu_layout.py +++ /dev/null @@ -1,285 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-from dataclasses import dataclass -from typing import Optional, Tuple - -import torch -from torch.utils._python_dispatch import ( - return_and_correct_aliasing, -) - -from torchao.dtypes.affine_quantized_tensor import ( - AffineQuantizedTensor, - register_layout, -) -from torchao.dtypes.utils import AQTTensorImpl, Layout, PlainLayout, is_device -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, - fill_defaults, -) - -from ..uintx.int4_cpu_layout import ( - _is_float, -) - -aten = torch.ops.aten - - -@dataclass(frozen=True) -class Float8DynamicActFloat8WeightCPULayout(Layout): - """Layout class for float8 da8w8 CPU layout for affine quantized tensor""" - - pass - - -@register_layout(Float8DynamicActFloat8WeightCPULayout) -class Float8DynActFloat8WeiCpuAQTTensorImpl(AQTTensorImpl): - """TensorImpl for float8 da8w8 CPU layout for affine quantized tensor""" - - def __new__( - cls, - packed_weight: torch.Tensor, - scales: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - kwargs = {} - kwargs["device"] = packed_weight.device - kwargs["layout"] = ( - kwargs.get("layout") - if kwargs.get("layout", False) - else packed_weight.layout - ) - kwargs["dtype"] = packed_weight.dtype - kwargs["requires_grad"] = False - shape = packed_weight.shape - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - packed_weight: torch.Tensor, - scales: torch.Tensor, - transposed: bool, - _layout: Layout, - ): - self.packed_weight = packed_weight - self.scales = scales - self.transposed = transposed - self._layout = _layout - - def __tensor_flatten__(self): - return ["packed_weight", "scales"], [ - self.transposed, - self._layout, - ] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - packed_weight, scales = ( - tensor_data_dict["packed_weight"], - tensor_data_dict["scales"], - ) - (transposed, _layout) = tensor_attributes - return cls(packed_weight, scales, transposed, _layout) - - @classmethod - def from_plain( - cls, - data: torch.Tensor, - scale: torch.Tensor, - zero_point: Optional[torch.Tensor], - _layout: Layout, - ): - assert isinstance(_layout, Float8DynamicActFloat8WeightCPULayout) - assert data.dtype == torch.float8_e4m3fn, ( - "Float8 DA8W8 CPU: expects float8_e4m3fn weight" - ) - if scale.dim() == 1: - scale.unsqueeze_(-1) - scale = scale.to(torch.float) - - N = data.size(0) - K = data.size(-1) - if N % 32 == 0 and K % 32 == 0: - # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. - # Pack inner blocks [block_k, block_n] to VNNI layout if AMX is available. - # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n]. 
- weight_packed, scales = torch.ops.torchao.float8_linear_prepack_cpu( - data, scale - ) - else: - weight_packed = data - scales = scale - _layout = PlainLayout() - return cls(weight_packed, scales, False, _layout) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.packed_weight), - fn(self.scales), - self.transposed, - self._layout, - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - kwargs = {} if kwargs is None else kwargs - - if func is aten.detach.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - if func is aten.clone.default: - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - if func is aten.slice.Tensor: - self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) - if dim in [0, 1]: - assert step == 1, "Only step == 1 is supported in slicing right now" - data, scale = self.get_plain() - data_len = data.shape[dim] - scale_len = scale.shape[dim] - ratio = data_len / scale_len - start_scale = int(start / ratio) - end_scale = int(end / ratio) - - data = aten.slice.Tensor(data, dim, start, end, step) - scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - # this is to handle padding - data, scale = self._layout.post_process(data, scale, self.block_size) - sliced = self.from_plain(data, scale, self._layout) - return return_and_correct_aliasing(func, args, kwargs, sliced) - else: - raise NotImplementedError( - f"{cls.__name__} dispatch: attempting to run {func}, with dim={dim}, that is not supported" - ) - - raise NotImplementedError( - f"{cls.__name__} dispatch: attempting to run {func}, this is not supported" - ) - - __torch_function__ = torch._C._disabled_torch_function_impl - - @property - def block_size(self): - assert len(self.packed_weight.shape) == 2 - weight_shape = self.packed_weight.shape - N = weight_shape[0] - K = weight_shape[1] - groups = self.scales.numel() // N - group_size = K // groups - return (1, group_size) - - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - if isinstance(self._layout, PlainLayout): - # If the layout is PlainLayout, return the packed weight and scales directly - return ( - self.packed_weight, - self.scales, - torch.zeros_like(self.scales), - ) - # Unpack weight by linear(eye(K), packed_weight).t() - packed_w_shape = self.packed_weight.shape - if len(packed_w_shape) == 4: - K = packed_w_shape[1] * packed_w_shape[2] - else: - K = packed_w_shape[1] - x = torch.eye(K).to(torch.float8_e4m3fn) - x_scale = torch.ones(K).float() - w_scale = torch.ones_like(self.scales).float() - plain_weight = torch.ops.torchao.float8_linear_cpu.default( - x, - x_scale, - self.packed_weight, - w_scale, - None, # bias - torch.float, # out_dtype - ) - plain_weight = plain_weight.t().contiguous() - plain_weight = plain_weight.to(torch.float8_e4m3fn) - - if self.scales.dim() == 2: - plain_scales = self.scales - else: - assert self.scales.dim() == 3 - packed_shape = self.scales.shape # [Nc, G, block_n] - plain_scales = ( - self.scales.permute([0, 2, 1]).contiguous().view([-1, packed_shape[1]]) - ) - - return plain_weight, plain_scales, torch.zeros_like(plain_scales) - - -def _aqt_is_float8e4m3(aqt): - """Check if an AffineQuantizedTensor is float8_e4m3fn quantized Tensor""" - return aqt.tensor_impl.dtype == torch.float8_e4m3fn - - -def _float8_linear_cpu_check(input_tensor, weight_tensor, bias): - return ( - TORCH_VERSION_AT_LEAST_2_6 - and 
is_device(input_tensor.device.type, "cpu") - and is_device(weight_tensor.device.type, "cpu") - and (bias is None or is_device(bias.device.type, "cpu")) - and isinstance(input_tensor, AffineQuantizedTensor) - and _aqt_is_float8e4m3(input_tensor) - and _is_float(input_tensor.dtype) - and isinstance(input_tensor._layout, PlainLayout) - and isinstance(weight_tensor, AffineQuantizedTensor) - and _aqt_is_float8e4m3(weight_tensor) - and _is_float(weight_tensor.dtype) - and isinstance(weight_tensor._layout, Float8DynamicActFloat8WeightCPULayout) - ) - - -def _float8_linear_cpu_impl(input_tensor, weight_tensor, bias): - assert TORCH_VERSION_AT_LEAST_2_6, ( - f"Requires PyTorch version at least 2.6, but got: {torch.__version__}" - ) - assert is_device(input_tensor.device.type, "cpu"), ( - f"For CPU device only but got: {input_tensor.device}" - ) - assert weight_tensor.block_size[0] == 1, ( - f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" - ) - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( - f"need input_tensor shape: {input_tensor.shape} final" - f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " - ) - - act_mat = input_tensor - act = act_mat.tensor_impl.int_data - act_scales = act_mat.tensor_impl.scale - - packed_weight = weight_tensor.tensor_impl.packed_weight - wei_scales = weight_tensor.tensor_impl.scales - - orig_act_size = act_mat.size() - orig_dtype = act_mat.dtype - - # reshape to 2D - act = act.reshape(-1, act.shape[-1]) - - y = torch.ops.torchao.float8_linear_cpu.default( - act.contiguous(), - act_scales, - packed_weight, - wei_scales, - bias.float() if bias is not None else bias, # requires bias to be float - torch.float, # out_dtype - ) - - # remove out_feature padding - orig_out_features = weight_tensor.shape[-2] - y = y[:, :orig_out_features] - y = y.reshape(*orig_act_size[:-1], orig_out_features) - - return y.to(orig_dtype) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index 90e42747b4..28de68a62e 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -89,6 +89,7 @@ quantize_affine, ) from .quantize_.workflows import ( + Float8OpaqueTensor, Float8Tensor, Int4MarlinSparseTensor, Int4OpaqueTensor, @@ -168,6 +169,7 @@ "IntxUnpackedToInt8Tensor", "Float8Tensor", "Int4OpaqueTensor", + "Float8OpaqueTensor", # smooth quant - subject to change "get_scale", "SmoothFakeDynQuantMixin", diff --git a/torchao/quantization/observer.py b/torchao/quantization/observer.py index 6d928a4477..bff189df18 100644 --- a/torchao/quantization/observer.py +++ b/torchao/quantization/observer.py @@ -15,8 +15,10 @@ from .granularity import ( Granularity, PerAxis, + PerGroup, PerRow, PerTensor, + PerToken, ) from .quant_primitives import ( MappingType, @@ -78,8 +80,13 @@ def get_block_size( block_size = list(input_shape) block_size[granularity.axis] = 1 return tuple(block_size) - elif isinstance(granularity, PerRow): + elif isinstance(granularity, (PerRow, PerToken)): return (1,) * (len(input_shape) - 1) + (input_shape[-1],) + elif isinstance(granularity, PerGroup): + assert input_shape[-1] % granularity.group_size == 0, ( + f"Group size {granularity.group_size} does not divide input shape {input_shape}" + ) + return (1,) * (len(input_shape) - 1) + (granularity.group_size,) raise ValueError(f"Unsupported Granularity: {granularity}") diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3d99cd1a9c..df73d48f99 100644 --- 
a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -71,6 +71,7 @@ PackingFormat, ) from torchao.quantization.quantize_.workflows import ( + Float8OpaqueTensor, Float8Tensor, Int4MarlinSparseTensor, Int4OpaqueTensor, @@ -1708,6 +1709,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): kernel_preference: KernelPreference = KernelPreference.AUTO set_inductor_config: bool = True version: int = 2 + packing_format: PackingFormat = PackingFormat.PLAIN def __post_init__(self): torch._C._log_api_usage_once( @@ -1733,6 +1735,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_value_lb = config.activation_value_lb activation_value_ub = config.activation_value_ub kernel_preference = config.kernel_preference + packing_format = config.packing_format # Ensure works on device activation_granularity, weight_granularity = granularity @@ -1790,14 +1793,22 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): kernel_preference=kernel_preference, ) - quantized_weight = Float8Tensor.from_hp( - weight, - float8_dtype=weight_dtype, - granularity=weight_granularity, - mm_config=mm_config, - kernel_preference=kernel_preference, - act_quant_kwargs=act_quant_kwargs, - ) + if packing_format == PackingFormat.PLAIN: + quantized_weight = Float8Tensor.from_hp( + weight, + float8_dtype=weight_dtype, + granularity=weight_granularity, + mm_config=mm_config, + kernel_preference=kernel_preference, + act_quant_kwargs=act_quant_kwargs, + ) + elif packing_format == PackingFormat.OPAQUE: + block_size = get_block_size(weight.shape, weight_granularity) + quantized_weight = Float8OpaqueTensor.from_hp( + weight, + block_size=block_size, + act_quant_kwargs=act_quant_kwargs, + ) return quantized_weight diff --git a/torchao/quantization/quantize_/workflows/__init__.py b/torchao/quantization/quantize_/workflows/__init__.py index 863608050e..0eff9d9f1f 100644 --- a/torchao/quantization/quantize_/workflows/__init__.py +++ b/torchao/quantization/quantize_/workflows/__init__.py @@ -1,3 +1,6 @@ +from .float8.float8_opaque_tensor import ( + Float8OpaqueTensor, +) from .float8.float8_tensor import ( Float8Tensor, QuantizeTensorToFloat8Kwargs, @@ -25,6 +28,7 @@ "Int4Tensor", "Int4PreshuffledTensor", "Int4MarlinSparseTensor", + "Float8OpaqueTensor", "Float8Tensor", "QuantizeTensorToFloat8Kwargs", "IntxOpaqueTensor", diff --git a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py new file mode 100644 index 0000000000..6dc67cf335 --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py @@ -0,0 +1,222 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import List, Optional + +import torch + +from torchao.quantization.granularity import ( + PerGroup, + PerTensor, + PerToken, +) +from torchao.quantization.quant_primitives import ( + _choose_scale_float8, + _quantize_affine_float8, +) +from torchao.utils import ( + TorchAOBaseTensor, +) + +from .float8_tensor import QuantizeTensorToFloat8Kwargs + +__all__ = [ + "Float8OpaqueTensor", +] + +aten = torch.ops.aten + + +class Float8OpaqueTensor(TorchAOBaseTensor): + """ + Float8 dynamic activation float8 weight on CPU. 
The weight tensor is reordered to a blocked layout + for better memory locality from [N, K] to [N/block_n, K/block_k, block_k, block_n], where block_n = 32 + and block_k depends on group-size for quantization (=32/64/128). And the innermost block with shape + [block_k, block_n] may be further reordered to VNNI layout depending on supported CPU ISA. + + Tensor Attributes: + qdata: Reordered float8 weight on CPU with shape = [N/block_n, K/block_k, block_k, block_n]. + scale: Scale tensor for weight, dtype = float32. For per-group/row quantization, shape = + [N / block_n, num_groups, block_n]. For per-tensor quantization, shape = [1]. + + Non-Tensor Attributes: + block_size: the block size for quantization, representing the granularity. for groupwise quantization, + block_size is (1, group_size). we only support group_size = 32/64/128. For per-row + quantization, blocks_size is (1, K). For per-tensor quantization, block_size is (N, K). + shape: shape of the original Tensor + act_quant_args: the kwargs for from_hp + """ + + tensor_data_names = ["qdata", "scale"] + tensor_attribute_names = ["block_size", "shape", "act_quant_args"] + + def __new__( + cls, + qdata, + scale, + block_size, + act_quant_args, + ): + shape = qdata.shape + kwargs = {} + kwargs["device"] = qdata.device + kwargs["dtype"] = scale.dtype + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + qdata: torch.Tensor, + scale: torch.Tensor, + block_size: List[int], + act_quant_args: Optional[QuantizeTensorToFloat8Kwargs] = None, + ): + self.qdata = qdata + self.scale = scale + self.block_size = block_size + self.act_quant_args = act_quant_args + + def _quantization_type(self): + return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, {self.act_quant_kwargs=}" + + @classmethod + def from_hp( + cls, + hp_tensor: torch.Tensor, + block_size: List[int], + act_quant_args: Optional[QuantizeTensorToFloat8Kwargs] = None, + ): + assert hp_tensor.ndim == 2 and hp_tensor.device.type == "cpu", ( + f"Expecting 2D tensor on CPU, but got: {hp_tensor.shape} on {hp_tensor.device.type}" + ) + assert len(block_size) == hp_tensor.ndim + N = hp_tensor.size(0) + K = hp_tensor.size(-1) + assert (block_size[0] == 1 or block_size[0] == N) and block_size[1] in ( + 32, + 64, + 128, + K, + ), f"Unsupported block_size: {block_size} for tensor shape {hp_tensor}" + assert N % 32 == 0, ( + f"Expecting out_features {N} to be multiple of 32, but got {N}" + ) + assert K % block_size[1] == 0, ( + f"Expecting in_features {K} to be multiple of group_size {block_size[1]}, but got {K}" + ) + scale = _choose_scale_float8( + hp_tensor, + float8_dtype=torch.float8_e4m3fn, + block_size=block_size, + ) + data = _quantize_affine_float8(hp_tensor, scale, torch.float8_e4m3fn) + # Pack weight from [N, K] to [N / block_n, K / block_k, block_k, block_n]. + # Pack scales from [N, num_groups] to [N / block_n, num_groups, block_n]. 
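As a rough sketch (illustration only, not the actual op): ignoring the optional VNNI repacking of the inner [block_k, block_n] tile and the per-tensor special case, the blocked layout described by the two comments above can be reproduced with a plain view/permute. All sizes below are example values chosen for the sketch.

import torch

# Hypothetical example sizes: block_n is fixed at 32, block_k equals the group size here.
N, K, block_n, block_k = 64, 128, 32, 64
num_groups = K // block_k
Nc, Kc = N // block_n, K // block_k

weight_fp8 = torch.randn(N, K).to(torch.float8_e4m3fn)   # [N, K]
scales = torch.rand(N, num_groups)                        # [N, num_groups]

# [N, K] -> [Nc, block_n, Kc, block_k] -> [Nc, Kc, block_k, block_n]
blocked_weight = weight_fp8.view(Nc, block_n, Kc, block_k).permute(0, 2, 3, 1).contiguous()
# [N, num_groups] -> [Nc, block_n, num_groups] -> [Nc, num_groups, block_n]
blocked_scales = scales.view(Nc, block_n, num_groups).permute(0, 2, 1).contiguous()

assert blocked_weight.shape == (Nc, Kc, block_k, block_n)
assert blocked_scales.shape == (Nc, num_groups, block_n)

Keeping whole [block_k, block_n] tiles contiguous is what provides the memory locality mentioned in the class docstring and lets each tile be fed to the microkernel as one block.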
+ packed_weight, packed_scale = torch.ops.torchao.float8_linear_prepack_cpu( + data, scale + ) + + return Float8OpaqueTensor( + qdata=packed_weight, + scale=packed_scale, + block_size=block_size, + act_quant_args=act_quant_args, + ) + + +implements = Float8OpaqueTensor.implements + + +@implements([torch.nn.functional.linear, aten.linear.default]) +def _(func, types, args, kwargs): + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + assert input_tensor.device.type == "cpu", ( + f"For CPU device only but got: {input_tensor.device}" + ) + assert isinstance(weight_tensor, Float8OpaqueTensor), ( + f"Expected weight_tensor to be Float8OpaqueTensor, got: {type(weight_tensor)}" + ) + assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" + ) + + act_mat = input_tensor.contiguous() + packed_weight = weight_tensor.qdata + scale = weight_tensor.scale + + orig_act_size = act_mat.size() + orig_dtype = act_mat.dtype + # reshape to 2D + act_mat = act_mat.reshape(-1, act_mat.shape[-1]) + + # activation float8 quantization + if ( + weight_tensor.act_quant_args is not None + and weight_tensor.act_quant_args.granularity is not None + ): + granularity = weight_tensor.act_quant_args.granularity + if isinstance(granularity, PerTensor): + act_scale = _choose_scale_float8( + act_mat, + float8_dtype=torch.float8_e4m3fn, + block_size=list(act_mat.shape), + ) + elif isinstance(granularity, PerToken): + act_scale = _choose_scale_float8( + act_mat, + float8_dtype=torch.float8_e4m3fn, + block_size=[1, act_mat.size(-1)], + ) + elif isinstance(granularity, PerGroup): + group_size = granularity.group_size + if weight_tensor.block_size[1] < weight_tensor.size(-1): + # weight_tensor is also per group quantized + assert weight_tensor.block_size[1] == group_size, ( + "input and weight should have the same group size but got" + f" {weight_tensor.block_size[1]} and {group_size}" + ) + act_scale = _choose_scale_float8( + act_mat, + float8_dtype=torch.float8_e4m3fn, + block_size=[1, group_size], + ) + else: + raise ValueError( + f"Unsupported activation quantization granularity: {granularity}" + ) + act_mat = _quantize_affine_float8(act_mat, act_scale, torch.float8_e4m3fn) + else: + raise NotImplementedError( + "Activation quantization args not provided for Float8OpaqueTensor" + ) + + # groupwise int4 quantization + y = torch.ops.torchao.float8_linear_cpu.default( + act_mat, + act_scale, + packed_weight, + scale, + bias.float() if bias is not None else bias, # requires bias to be float + torch.float, # out_dtype + ) + + # remove out_feature padding + orig_out_features = weight_tensor.shape[-2] + y = y[:, :orig_out_features] + y = y.reshape(*orig_act_size[:-1], orig_out_features) + + return y.to(orig_dtype) + + +Float8OpaqueTensor.__module__ = "torchao.quantization" + +# Allow a model with Float8OpaqueTensor weights to be loaded with `weights_only=True` +torch.serialization.add_safe_globals([Float8OpaqueTensor]) From afcee6bc7ce8db4fb1601637c25595129cef0d3d Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 9 Sep 2025 14:27:13 +0000 Subject: [PATCH 07/15] Update kernel implementation --- .../float8/test_float8_opaque_tensor.py | 259 ++++++++++++++ .../test_dynamic_float8_linear_cpu.py | 159 --------- torchao/csrc/cpu/float8_linear.cpp | 323 +++++++++++------- torchao/float8/inference.py | 23 +- torchao/float8/types.py | 4 +- torchao/ops.py | 4 +- 
torchao/quantization/quant_api.py | 27 +- .../workflows/float8/float8_opaque_tensor.py | 90 ++--- 8 files changed, 550 insertions(+), 339 deletions(-) create mode 100644 test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py delete mode 100644 test/quantization/test_dynamic_float8_linear_cpu.py diff --git a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py new file mode 100644 index 0000000000..e7f83d3533 --- /dev/null +++ b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import tempfile +import unittest + +import torch +from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import ( + TestCase, + run_tests, +) + +from torchao import quantize_ +from torchao.quantization import PerGroup, PerRow, PerTensor +from torchao.quantization.quant_api import ( + Float8DynamicActivationFloat8WeightConfig, +) +from torchao.quantization.utils import compute_error +from torchao.utils import ( + torch_version_at_least, +) + + +def get_config(granularity): + return Float8DynamicActivationFloat8WeightConfig( + activation_dtype=torch.float8_e4m3fn, + granularity=granularity, + packing_format="opaque", + ) + + +class ToyLinearModel(torch.nn.Module): + def __init__(self, K=64, N=32, bias=False): + super().__init__() + self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float) + self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float) + + def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): + return ( + torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device) + * 0.1, + ) + + def forward(self, x): + x = self.linear1(x) + x = self.linear2(x) + return x + + +class TestDynamicFloat8Linear(TestCase): + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config(PerRow()), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize( + "granularity", + [ + (PerTensor(), PerTensor()), + (PerTensor(), 
PerRow()), + (PerTensor(), PerGroup(64)), + ], + ) + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 128]) + def test_dynamic_float8_linear_per_tensor_cpu( + self, granularity, dtype, x_dim, bias, bs + ): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config(granularity), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias): + device = "cpu" + # the shape is not supported by cpp kernel, so the ref path will be used. + m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device) + bs = 4 + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config(PerRow()), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + @common_utils.parametrize("group_size", [32, 64, 128]) + def test_dynamic_float8_linear_per_group_cpu( + self, dtype, x_dim, bias, bs, group_size + ): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config([PerRow(), PerGroup(group_size)]), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in 
torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @unittest.skipIf(not torch_version_at_least("2.6.0"), "Test only enabled for 2.6+") + @common_utils.parametrize("dtype", [torch.bfloat16, torch.half]) + @common_utils.parametrize("x_dim", [2, 3]) + @common_utils.parametrize("bias", [True, False]) + @common_utils.parametrize("bs", [1, 160]) + @common_utils.parametrize("group_size", [32, 64, 128]) + def test_dynamic_float8_linear_per_group_act_cpu( + self, dtype, x_dim, bias, bs, group_size + ): + device = "cpu" + m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) + example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) + if x_dim == 3: + example_inputs = (example_inputs[0].unsqueeze(0),) + y = m(*example_inputs) + + with torch.no_grad(): + quantize_( + m, + get_config([PerGroup(group_size), PerGroup(group_size)]), + ) + y1 = m(*example_inputs) + assert compute_error(y, y1) > 20 + y2, code = torch._inductor.utils.run_and_get_code( + torch.compile(m, fullgraph=True, dynamic=True), + *example_inputs, + ) + # ensure the expected op is in the code + assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] + assert compute_error(y, y2) > 20 + + @unittest.skipIf( + "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), + reason="cpp kernels not built", + ) + @common_utils.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) + def test_module_path(self, dtype): + linear = torch.nn.Linear(128, 256, dtype=dtype) + quantize_(linear, get_config(PerRow())) + self.assertEqual( + str(type(linear.weight)), + "", + ) + + with tempfile.NamedTemporaryFile() as f: + torch.save(linear.state_dict(), f) + f.seek(0) + state_dict = torch.load(f) + self.assertEqual( + str(type(state_dict["weight"])), + "", + ) + + +common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear) + + +if __name__ == "__main__": + run_tests() diff --git a/test/quantization/test_dynamic_float8_linear_cpu.py b/test/quantization/test_dynamic_float8_linear_cpu.py deleted file mode 100644 index 7f81965a0e..0000000000 --- a/test/quantization/test_dynamic_float8_linear_cpu.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
- -import copy -import unittest - -import torch -from torch.testing._internal import common_utils -from torch.testing._internal.common_utils import ( - TestCase, - run_tests, -) - -from torchao import quantize_ -from torchao.dtypes import ( - Float8DynamicActFloat8WeightCPULayout, - PlainLayout, -) -from torchao.quantization import PerRow -from torchao.quantization.quant_api import ( - Float8DynamicActivationFloat8WeightConfig, -) -from torchao.utils import ( - TORCH_VERSION_AT_LEAST_2_6, -) - - -class ToyLinearModel(torch.nn.Module): - def __init__(self, K=64, N=32, bias=False): - super().__init__() - self.linear1 = torch.nn.Linear(K, N, bias=bias).to(torch.float) - self.linear2 = torch.nn.Linear(N, K, bias=bias).to(torch.float) - - def example_inputs(self, batch_size=1, dtype=torch.float, device="cpu"): - return ( - torch.rand(batch_size, self.linear1.in_features, dtype=dtype, device=device) - * 0.1, - ) - - def forward(self, x): - x = self.linear1(x) - x = self.linear2(x) - return x - - -class TestDynamicFloat8Linear(TestCase): - @unittest.skipIf( - "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), - reason="cpp kernels not built", - ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") - @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) - @common_utils.parametrize("x_dim", [2, 3]) - @common_utils.parametrize("bias", [True, False]) - @common_utils.parametrize("bs", [1, 160]) - def test_dynamic_float8_linear_cpu(self, dtype, x_dim, bias, bs): - device = "cpu" - m = ToyLinearModel(256, 256, bias=bias).eval().to(dtype).to(device) - m2 = copy.deepcopy(m) - example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) - if x_dim == 3: - example_inputs = (example_inputs[0].unsqueeze(0),) - - with torch.no_grad(): - quantize_( - m, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - layout=Float8DynamicActFloat8WeightCPULayout(), - ), - ) - y, code = torch._inductor.utils.run_and_get_code( - torch.compile(m, fullgraph=True, dynamic=True), - *example_inputs, - ) - # ensure the expected op is in the code - assert "torch.ops.torchao.float8_linear_cpu.default" in code[0] - quantize_( - m2, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - layout=PlainLayout(), - ), - ) - torch._dynamo.reset() # may segfault without this - y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) - atol, rtol = 1e-4, 1e-6 - if dtype == torch.bfloat16: - atol, rtol = 1.6e-2, 3e-3 - elif dtype == torch.half: - atol, rtol = 6e-3, 2e-3 - assert torch.allclose(y, y2, atol=atol, rtol=rtol) - # Test get_plain by dequantize() - dqw1 = m.linear1.weight.original_weight_tensor.dequantize() - dqw2 = m.linear2.weight.original_weight_tensor.dequantize() - dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() - dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() - assert torch.allclose(dqw1, dqw1_ref) - assert torch.allclose(dqw2, dqw2_ref) - - @unittest.skipIf( - "CPU" not in torch._C._dispatch_dump("torchao::float8_linear_cpu"), - reason="cpp kernels not built", - ) - @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Test only enabled for 2.6+") - @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) - @common_utils.parametrize("x_dim", [2, 3]) - @common_utils.parametrize("bias", [True, False]) - def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias): - device = "cpu" - # the shape is not supported by cpp kernel, so the ref path 
will be used. - m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device) - m2 = copy.deepcopy(m) - bs = 4 - example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) - if x_dim == 3: - example_inputs = (example_inputs[0].unsqueeze(0),) - - with torch.no_grad(): - quantize_( - m, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - layout=Float8DynamicActFloat8WeightCPULayout(), - ), - ) - y, code = torch._inductor.utils.run_and_get_code( - torch.compile(m, fullgraph=True, dynamic=True), - *example_inputs, - ) - # ensure the op is not in the code - assert "torch.ops.torchao.float8_linear_cpu.default" not in code[0] - quantize_( - m2, - Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - layout=PlainLayout(), - ), - ) - torch._dynamo.reset() # may segfault without this - y2 = torch.compile(m2, fullgraph=True, dynamic=True)(*example_inputs) - assert torch.allclose(y, y2) - # Test get_plain by dequantize() - dqw1 = m.linear1.weight.original_weight_tensor.dequantize() - dqw2 = m.linear2.weight.original_weight_tensor.dequantize() - dqw1_ref = m2.linear1.weight.original_weight_tensor.dequantize() - dqw2_ref = m2.linear2.weight.original_weight_tensor.dequantize() - assert torch.allclose(dqw1, dqw1_ref) - assert torch.allclose(dqw2, dqw2_ref) - - -common_utils.instantiate_parametrized_tests(TestDynamicFloat8Linear) - - -if __name__ == "__main__": - run_tests() diff --git a/torchao/csrc/cpu/float8_linear.cpp b/torchao/csrc/cpu/float8_linear.cpp index 586c3e9439..2de50e757a 100644 --- a/torchao/csrc/cpu/float8_linear.cpp +++ b/torchao/csrc/cpu/float8_linear.cpp @@ -2,6 +2,7 @@ #include #include #include +#include "dispatcher.h" namespace torchao { @@ -9,6 +10,10 @@ namespace { #define BLOCK_N 32 +#define PER_TENSOR 1 +#define PER_ROW 2 +#define PER_GROUP 3 + static bool cpublas_checked = false; static bool cpublas_can_pack = false; @@ -17,7 +22,11 @@ bool cpublas_could_pack() { if (cpublas_checked) { return cpublas_can_pack; } +#ifdef CPUBLAS_BRGEMM_F8F8F32 + cpublas_can_pack = at::native::cpublas::could_pack(at::kFloat8_e4m3fn); +#else cpublas_can_pack = at::native::cpublas::could_pack(at::kBFloat16); +#endif cpublas_checked = true; return cpublas_can_pack; } @@ -33,16 +42,20 @@ float8_linear_prepack_impl( // scales shape = [N, G] TORCH_CHECK(weight.dim() == 2, "Float8 linear CPU: Weight should be a 2D tensor for packing"); - TORCH_CHECK(weight.size(1) % 2 == 0, - "Float8 linear CPU: Weight should have even number of columns for packing"); + int N = weight.size(0); + int K = weight.size(1); + constexpr int block_n = BLOCK_N; + // Case to fall back + if (N % block_n != 0 || K % 32 != 0) { + return std::make_tuple(weight, scales); + } auto new_scales = scales; - if (new_scales.dim() == 1) { + bool is_per_tensor = new_scales.numel() == 1; + if (new_scales.dim() == 1 && !is_per_tensor) { new_scales.unsqueeze_(1); } new_scales = new_scales.to(at::kFloat); - int N = weight.size(0); - int K = weight.size(1); int G = scales.size(1); TORCH_CHECK(K % G == 0, "K should be divisible by num_groups"); int group_size = K / G; @@ -52,8 +65,6 @@ float8_linear_prepack_impl( } TORCH_CHECK(block_k > 0 && block_k <= group_size, "Float8 linear CPU: Invalid block_k size, should be in (0, group_size]"); - constexpr int block_n = BLOCK_N; - TORCH_CHECK(N % block_n == 0, "N should be divisible by 32"); int Nc = N / block_n; int Kc = K / block_k; @@ -62,11 +73,15 @@ float8_linear_prepack_impl( auto weight_view = weight.view({Nc, block_n, Kc, block_k}); 
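  // Note: weight_view is [Nc, block_n, Kc, block_k]; the permute below
  // rearranges it to [Nc, Kc, block_k, block_n], so each [block_k, block_n]
  // tile is contiguous in memory before the optional VNNI repacking that
  // follows in the AVX512 branch.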
at::Tensor weight_reordered = weight_view.permute({0, 2, 3, 1}).contiguous(); at::Tensor blocked_weight; - at::Tensor blocked_scales = new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); + at::Tensor blocked_scales = is_per_tensor ? new_scales.view({1}) : new_scales.view({Nc, block_n, G}).permute({0, 2, 1}).contiguous(); #if defined(CPU_CAPABILITY_AVX512) if (cpublas_could_pack()) { +#ifdef CPUBLAS_BRGEMM_F8F8F32 + constexpr int vnni_size = 4; // for fp8 +#else constexpr int vnni_size = 2; // for float16 +#endif blocked_weight = at::empty({Nc, Kc, block_k, block_n}, weight.options()); auto weight_ptr = reinterpret_cast(weight_reordered.data_ptr()); auto blocked_weight_ptr = reinterpret_cast(blocked_weight.data_ptr()); @@ -192,8 +207,10 @@ static void _convert_A_to_bf16( } } -template -static void _dequant_and_store( +// accumulate and store result to buffer +// if act/wei are per_group quantized, apply scales +template +static void _store_result( float* __restrict__ output, const float* __restrict__ input, const float* __restrict__ scale_a, @@ -202,25 +219,41 @@ static void _dequant_and_store( int ldi, int ldo, int ldsa = 1) { + float a_scale, b_scale; + __m512 va_scale; + __m512 vb_scale; for (int m = 0; m < M; ++m) { - float a_scale = *(scale_a + m * ldsa); - __m512 va_scale = _mm512_set1_ps(a_scale); + if constexpr (act_quant_mode == PER_GROUP) { + a_scale = *(scale_a + m * ldsa); + va_scale = _mm512_set1_ps(a_scale); + } int n = 0; #pragma GCC unroll 2 for (; n < N; n += 16) { __m512 vc_f = _mm512_loadu_ps(input + m * ldi + n); - __m512 vc_f_mul = _mm512_mul_ps(vc_f, va_scale); - __m512 vb_s = _mm512_loadu_ps(scale_b + n); - vc_f_mul = _mm512_mul_ps(vc_f_mul, vb_s); + if constexpr (act_quant_mode == PER_GROUP) { + vc_f = _mm512_mul_ps(vc_f, va_scale); + } + if constexpr (wei_quant_mode == PER_GROUP) { + vb_scale = _mm512_loadu_ps(scale_b + n); + vc_f = _mm512_mul_ps(vc_f, vb_scale); + } if constexpr (accum) { __m512 vo = _mm512_loadu_ps(output + m * ldo + n); - _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f_mul)); + _mm512_storeu_ps(output + m * ldo + n, _mm512_add_ps(vo, vc_f)); } else { - _mm512_storeu_ps(output + m * ldo + n, vc_f_mul); + _mm512_storeu_ps(output + m * ldo + n, vc_f); } } for (; n < N; ++n) { - float dq_val = input[m * ldi + n] * a_scale * scale_b[n]; + float dq_val = input[m * ldi + n]; + if constexpr (act_quant_mode == PER_GROUP) { + dq_val = dq_val * a_scale; + } + if constexpr (wei_quant_mode == PER_GROUP) { + b_scale = scale_b[n]; + dq_val = dq_val * b_scale; + } if constexpr (accum) { output[m * ldo + n] += dq_val; } else { @@ -254,8 +287,8 @@ static void _convert_A_to_bf16( } #endif -template -void _dequant_gemm_accum( +template +void _micro_gemm( float* C, const at::Float8_e4m3fn* A, const float* scales_a, @@ -264,16 +297,35 @@ void _dequant_gemm_accum( int64_t M, int64_t K, int64_t lda, - int64_t ldc) { - // Compute GEMM fp8 * fp8 -> fp32 - // Then apply scales and store results + int64_t ldc, + int64_t ldsa) { + // If FP8 brgemm is not available, convert A/B to bf16 for computation + // Compute GEMM fp8 * fp8 -> fp32 (or bf16 * bf16 -> fp32) + // If per_group quant, apply scales. 
Otherwise, don't apply scales here + // Finally accumulate and store results +#ifndef CPUBLAS_BRGEMM_F8F8F32 at::BFloat16 dqB[K * N]; _convert_B_to_bf16(B, dqB, K * N); at::BFloat16 dqA[M * K]; _convert_A_to_bf16(A, dqA, M, K, lda); +#endif #if defined(CPU_CAPABILITY_AVX512) if constexpr (cpublas_can_pack) { float C_f32[M * N]; +#ifdef CPUBLAS_BRGEMM_F8F8F32 + at::native::cpublas::brgemm( + M, + N, + K, + lda /*lda*/, + N /*ldb*/, + N /*ldc*/, + false /* add_C */, + A, + B, + C_f32, + true /* is_vnni */); +#else at::native::cpublas::brgemm( M, N, @@ -286,9 +338,10 @@ void _dequant_gemm_accum( dqB, C_f32, true /* is_vnni */); +#endif _mm_prefetch(B + N * K, _MM_HINT_T0); _mm_prefetch(A + K, _MM_HINT_T0); - _dequant_and_store( + _store_result( C, C_f32, scales_a, @@ -296,7 +349,7 @@ void _dequant_gemm_accum( M, N /*ldi*/, ldc, - 1 /*ldsa*/); + ldsa); } else #endif { @@ -304,93 +357,96 @@ void _dequant_gemm_accum( for (int64_t j = 0; j < N; ++j) { float sum = 0; for (int64_t k = 0; k < K; ++k) { +#ifdef CPUBLAS_BRGEMM_F8F8F32 + sum += ((float)A[i * lda + k] * (float)B[k * N + j]); +#else sum += ((float)dqA[i * K + k] * dqB[k * N + j]); +#endif + } + if constexpr (act_quant_mode == PER_GROUP) { + sum *= scales_a[i * ldsa]; + } + if constexpr (wei_quant_mode == PER_GROUP) { + sum *= scales_b[j]; } - C[i * ldc + j] += sum * scales_a[i] * scales_b[j]; + C[i * ldc + j] += sum; } } } } -template -inline void copy_bias(const float* bias_ptr, float* y_buf, int64_t m) { - if (bias_ptr) { - for (int i = 0; i < m; ++i) { - int j = 0; +// Store result to output buffer with dtype conversion +// If act/wei are per_row or per_tensor quantized, apply scales +// If bias is not null, add bias +template +inline void store_out( + const float* y_buf, + out_dtype* c_ptr, + int64_t M, + int64_t lda, + const float* scales_a, + const float* scales_b, + const float* bias) { + float a_scale = 1.0, b_scale = 1.0; #if defined(CPU_CAPABILITY_AVX512) -#pragma GCC unroll 2 - for (; j < N; j += 16) { - __m512 bias_vec = _mm512_loadu_ps(bias_ptr + j); - _mm512_storeu_ps(y_buf + i * N + j, bias_vec); - } + __m512 va_scale, vb_scale; #endif - for (; j < N; ++j) { - y_buf[i * N + j] = bias_ptr[j]; - } - } - } else { // initialize to zero - for (int i = 0; i < m; ++i) { - int j = 0; + if constexpr (act_quant_mode == PER_TENSOR) { + a_scale = *scales_a; + } + if constexpr (wei_quant_mode == PER_TENSOR) { + b_scale = *scales_b; #if defined(CPU_CAPABILITY_AVX512) -#pragma GCC unroll 2 - for (; j < N; j += 16) { - __m512 zero_vec = _mm512_setzero_ps(); - _mm512_storeu_ps(y_buf + i * N + j, zero_vec); - } + vb_scale = _mm512_set1_ps(b_scale); #endif - for (; j < N; ++j) { - y_buf[i * N + j] = 0; - } - } } -} - -template -inline void store_out(const float* y_buf, out_dtype* c_ptr, int64_t m, /* int64_t n, */ int64_t lda) { - for (int i = 0; i < m; ++i) { + for (int i = 0; i < M; ++i) { + if constexpr (act_quant_mode == PER_ROW) { + a_scale = *(scales_a + i); + } int j = 0; - if constexpr (std::is_same::value) { #if defined(CPU_CAPABILITY_AVX512) + if constexpr (act_quant_mode != PER_GROUP) { + va_scale = _mm512_set1_ps(a_scale); + } #pragma GCC unroll 2 - for (; j < N; j += 16) { - __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); - _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + for (; j < N; j += 16) { + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m512 bias_vec = bias ? 
_mm512_loadu_ps(bias + j) : _mm512_setzero_ps(); + if constexpr (act_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, va_scale); } -#endif - for (; j < N; ++j) { - c_ptr[i * lda + j] = y_buf[i * N + j]; + if constexpr (wei_quant_mode == PER_ROW) { + vb_scale = _mm512_loadu_ps(scales_b + j); } - } else if constexpr (std::is_same::value) { -#if defined(CPU_CAPABILITY_AVX512) -#pragma GCC unroll 2 - for (; j < N; j += 16) { - __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + if constexpr (wei_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, vb_scale); + } + y_vec = _mm512_add_ps(y_vec, bias_vec); + if constexpr (std::is_same::value) { + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } else if constexpr (std::is_same::value) { __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); - } -#endif - for (; j < N; ++j) { - c_ptr[i * lda + j] = at::BFloat16(y_buf[i * N + j]); - } - } else if constexpr (std::is_same::value) { -#if defined(CPU_CAPABILITY_AVX512) -#pragma GCC unroll 2 - for (; j < N; j += 16) { - __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + } else if constexpr (std::is_same::value) { __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } else { + TORCH_CHECK(false, "Unsupported output dtype"); } -#endif - for (; j < N; ++j) { - c_ptr[i * lda + j] = at::Half(y_buf[i * N + j]); + } +#else + for (; j < N; ++j) { + if constexpr (wei_quant_mode == PER_ROW) { + b_scale = scales_b[j]; } - } else { - TORCH_CHECK(false, "Unsupported output dtype"); + c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); } - } +#endif + } // for M } -template +template void _float8_linear_impl( const at::Tensor& input, const at::Tensor& input_scales, @@ -403,7 +459,6 @@ void _float8_linear_impl( int64_t K = input.size(-1); auto input_view = input.view({-1, K}); int64_t M = input_view.size(0); - TORCH_CHECK(input_scales.numel() == M, "Float8 linear: unexpected input scales shape"); // weight shape = [Nc, Kc, block_k, block_n] // scales shape = [Nc, G, block_n] @@ -430,12 +485,14 @@ void _float8_linear_impl( int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc; // scales shape = [Nc, G, block_n] - int64_t num_groups = weight_scales.size(1); + int64_t num_groups = wei_quant_mode == PER_TENSOR ? 1 : weight_scales.size(1); TORCH_CHECK(K % num_groups == 0, "K should be divisible by num_groups"); int64_t group_size = K / num_groups; TORCH_CHECK(group_size % block_k == 0, "Float8 linear: group_size should be divisible by block_k"); int64_t block_per_group = group_size / block_k; + TORCH_CHECK(input_scales.numel() == 1 || input_scales.numel() == M || input_scales.numel() == M * num_groups, "Float8 linear: unexpected input scales shape"); + auto ldsa = act_quant_mode == PER_TENSOR ? 0 : act_quant_mode == PER_ROW ? 1 : num_groups; const at::Float8_e4m3fn* a_ptr = input_view.data_ptr(); const float* a_scales_ptr = input_scales.data_ptr(); @@ -452,28 +509,36 @@ void _float8_linear_impl( for (int mci = mc; mci < mc_end; ++mci) { int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; - alignas(64) float y_buf[m_size][block_n]; - // copy bias to y_buf if bias is not None - auto bias_data = bias_ptr ? 
bias_ptr + nc * block_n : nullptr; - copy_bias(bias_data, y_buf[0], m_size); + alignas(64) float y_buf[m_size][block_n] = {0}; for (int kci = 0; kci < Kc; ++kci) { - _dequant_gemm_accum( + auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group; + auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n; + _micro_gemm( y_buf[0] /*C*/, a_ptr + mci * block_m * K + kci * block_k /*A*/, - a_scales_ptr + mci * block_m /*scales_a*/, + scales_a /*scales_a*/, b_ptr + (nc * Kc + kci) * block_n * block_k /*B*/, - b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n /*scales_b*/, + scales_b /*scales_b*/, m_size /*M*/, block_k /*K*/, K /*lda*/, - block_n /*ldc*/); + block_n /*ldc*/, + ldsa /*ldsa*/); } // store y_buf to output with dtype conversion - store_out( + auto scales_a = act_quant_mode == PER_TENSOR ? a_scales_ptr : + act_quant_mode == PER_ROW ? a_scales_ptr + mci * block_m : nullptr; + auto scales_b = wei_quant_mode == PER_TENSOR ? b_scales_ptr : + wei_quant_mode == PER_ROW ? b_scales_ptr + nc * block_n : nullptr; + auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; + store_out( y_buf[0], c_ptr + mci * block_m * N + nc * block_n, m_size, - N /*lda*/); + N /*lda*/, + scales_a, + scales_b, + bias_data); } } if constexpr (cpublas_can_pack) { @@ -489,29 +554,55 @@ at::Tensor float8_linear_impl( const at::Tensor& weight_scales, const std::optional& bias, at::ScalarType output_dtype) { + int64_t N = weight.dim() == 4 ? weight.size(0) * weight.size(-1) : weight.size(0); + int act_quant_mode = input_scales.numel() == 1 ? PER_TENSOR : + input_scales.numel() == input.numel() / input.size(-1) ? PER_ROW : + PER_GROUP; + int wei_quant_mode = weight_scales.numel() == 1 ? PER_TENSOR : + weight_scales.numel() == N ? 
PER_ROW : + PER_GROUP; + // Case to fall back + if (weight.dim() == 2) { + TORCH_CHECK(act_quant_mode != PER_GROUP && wei_quant_mode != PER_GROUP, + "FP8 linear: Per-group quantization is not supported in the fallback path"); + auto y_fp32 = at::linear(input.to(at::kFloat).mul(input_scales), weight.to(at::kFloat).mul(weight_scales), bias); + return y_fp32.to(output_dtype); + } + static bool cpublas_can_pack = cpublas_could_pack(); auto out_sizes = input.sizes().vec(); - int64_t N = weight.size(0) * weight.size(-1); out_sizes.back() = N; auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); -#define call__float8_linear_impl(cpublas_can_pack) \ - AT_DISPATCH_FLOATING_TYPES_AND2( \ - at::ScalarType::BFloat16, at::ScalarType::Half, output_dtype, "float8_linear_cpu", [&] { \ - _float8_linear_impl( \ - input, \ - input_scales, \ - weight, \ - weight_scales, \ - bias, \ - output); \ - }); - - if (cpublas_can_pack) { - call__float8_linear_impl(true); - } else { - call__float8_linear_impl(false); - } + product_dispatcher< + std::tuple< + /*output_dtype*/ at::ScalarType, + /*cpublas_can_pack*/ bool, + /*act_quant_mode*/ int, + /*wei_quant_mode*/ int>, + std::tuple< + enumerate_dispatcher, + enumerate_dispatcher, + enumerate_dispatcher, + enumerate_dispatcher>>:: + call( + std::make_tuple(output_dtype, cpublas_can_pack, act_quant_mode, wei_quant_mode), + [&](auto tuple) { + constexpr auto o_dtype = std::get<0>(tuple); + using out_dtype = typename c10::impl::ScalarTypeToCPPType::type; + constexpr bool cpublas_can_pack_v = std::get<1>(tuple); + constexpr int act_quant_mode_v = std::get<2>(tuple); + constexpr int wei_quant_mode_v = std::get<3>(tuple); + _float8_linear_impl( + input, + input_scales, + weight, + weight_scales, + bias, + output); + }, + [](auto tuple) { TORCH_CHECK(false, "Not implemented for this configuration"); }); + return output; } diff --git a/torchao/float8/inference.py b/torchao/float8/inference.py index f15d38576c..4fa69ecd78 100644 --- a/torchao/float8/inference.py +++ b/torchao/float8/inference.py @@ -14,6 +14,7 @@ from torchao.float8.float8_utils import is_row_major, pad_tensor_for_matmul from torchao.float8.types import FP8Granularity from torchao.quantization.granularity import ( + PerGroup, PerRow, PerTensor, ) @@ -205,22 +206,19 @@ def _normalize_granularity( ] ], ) -> Tuple[FP8Granularity, FP8Granularity]: + supported_granularities = (PerTensor, PerRow, PerGroup) processed_granularity = None if granularity is None: processed_granularity = (PerTensor(), PerTensor()) - elif isinstance(granularity, (PerTensor, PerRow)): + elif isinstance(granularity, supported_granularities): processed_granularity = (granularity, granularity) elif isinstance(granularity, (tuple, list)) and len(granularity) == 2: if not ( - isinstance(granularity[0], (PerTensor, PerRow)) - and isinstance(granularity[1], (PerTensor, PerRow)) + isinstance(granularity[0], supported_granularities) + and isinstance(granularity[1], supported_granularities) ): raise ValueError( - f"Invalid granularity types: {granularity}, only PerTensor or PerRow are supported." - ) - if not isinstance(granularity[0], type(granularity[1])): - raise ValueError( - f"Different granularities for activation and weight are not supported: {granularity}, only PerTensor or PerRow are supported." + f"Invalid granularity types: {granularity}, only PerTensor or PerRow or PerGroup are supported." 
) processed_granularity = tuple(granularity) else: @@ -232,6 +230,7 @@ def _normalize_granularity( def _check_hardware_support( granularities: Tuple[FP8Granularity, FP8Granularity], + device: str = "cuda", ) -> None: """ Validate that the hardware supports the requested granularities. @@ -243,12 +242,16 @@ def _check_hardware_support( AssertionError: If hardware doesn't support the requested granularity ValueError: If invalid granularity type is provided """ + supported_granularities = ( + (PerTensor, PerRow, PerGroup) if device == "cpu" else (PerTensor, PerRow) + ) for _granularity in granularities: - if not isinstance(_granularity, (PerTensor, PerRow)): + if not isinstance(_granularity, supported_granularities): raise ValueError( - f"Invalid granularity type: {_granularity}, only PerTensor or PerRow are supported." + f"Invalid granularity type: {_granularity}, only {supported_granularities} are supported." ) + if device != "cpu": assert is_sm_at_least_89() or is_MI300(), ( "Float8 dynamic quantization requires CUDA compute capability ≥8.9 or MI300+." ) diff --git a/torchao/float8/types.py b/torchao/float8/types.py index b332a9629a..63cabc9582 100644 --- a/torchao/float8/types.py +++ b/torchao/float8/types.py @@ -12,8 +12,8 @@ from typing import TYPE_CHECKING, Union if TYPE_CHECKING: - from torchao.quantization.granularity import PerRow, PerTensor + from torchao.quantization.granularity import PerGroup, PerRow, PerTensor # Define FP8Granularity type alias to break circular import dependencies -FP8Granularity = Union["PerTensor", "PerRow"] +FP8Granularity = Union["PerTensor", "PerRow", "PerGroup"] diff --git a/torchao/ops.py b/torchao/ops.py index c8227bdbf0..77475d8185 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -1184,6 +1184,6 @@ def _( bias: Optional[Tensor], out_dtype: torch.dtype, ) -> Tensor: - assert weight.dim() == 4 - N = weight.size(0) * weight.size(3) + assert weight.dim() in (2, 4) + N = weight.size(0) * weight.size(3) if weight.dim() == 4 else weight.size(0) return input.new_empty(*input.shape[:-1], N, dtype=out_dtype) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index df73d48f99..f23b6c05fc 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1641,7 +1641,11 @@ def _input_activation_quant_cpu_fp8( activation_dtype: torch.dtype, ): """Dynamic quantize activation to fp8 for CPU.""" - block_size = get_block_size(x.shape, activation_granularity) + if not isinstance(activation_granularity, PerGroup): + block_size = get_block_size(x.shape, activation_granularity) + else: + group_size = activation_granularity.group_size + block_size = (*([1] * (len(x.shape) - 1)), group_size) return to_affine_quantized_floatx( input_float=x, block_size=block_size, @@ -1718,8 +1722,19 @@ def __post_init__(self): if self.mm_config is None: self.mm_config = Float8MMConfig(use_fast_accum=True) activation_granularity, weight_granularity = _normalize_granularity( - self.granularity + self.granularity, ) + if self.packing_format == PackingFormat.PLAIN: + assert isinstance(activation_granularity, (PerTensor, PerRow)), ( + f"Unsupported granularity {activation_granularity}, only PerTensor or PerRow are supported." + ) + assert isinstance(weight_granularity, (PerTensor, PerRow)), ( + f"Unsupported granularity {weight_granularity}, only PerTensor or PerRow are supported." 
+ ) + if not isinstance(activation_granularity, type(weight_granularity)): + raise ValueError( + f"Different granularities for activation and weight are not supported: {activation_granularity, weight_granularity}" + ) self.granularity = [activation_granularity, weight_granularity] @@ -1740,13 +1755,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): # Ensure works on device activation_granularity, weight_granularity = granularity is_cpu = weight.device.type == "cpu" - if is_cpu: - assert not ( - isinstance(activation_granularity, PerTensor) - or isinstance(weight_granularity, PerTensor) - ), "PerTensor quantization is not supported for CPU float8 quantization" - else: - _check_hardware_support(granularity) + _check_hardware_support(granularity, weight.device.type) if not is_cpu and not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently diff --git a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py index 6dc67cf335..ae284a3768 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py @@ -11,9 +11,8 @@ from torchao.quantization.granularity import ( PerGroup, - PerTensor, - PerToken, ) +from torchao.quantization.observer import get_block_size from torchao.quantization.quant_primitives import ( _choose_scale_float8, _quantize_affine_float8, @@ -48,20 +47,26 @@ class Float8OpaqueTensor(TorchAOBaseTensor): block_size is (1, group_size). we only support group_size = 32/64/128. For per-row quantization, blocks_size is (1, K). For per-tensor quantization, block_size is (N, K). shape: shape of the original Tensor - act_quant_args: the kwargs for from_hp + act_quant_kwargs: the kwargs for from_hp """ tensor_data_names = ["qdata", "scale"] - tensor_attribute_names = ["block_size", "shape", "act_quant_args"] + tensor_attribute_names = ["block_size", "act_quant_kwargs"] def __new__( cls, qdata, scale, block_size, - act_quant_args, + act_quant_kwargs, ): - shape = qdata.shape + if qdata.ndim == 2: + shape = qdata.shape + else: + assert qdata.ndim == 4 + shape = torch.Size( + [qdata.size(0) * qdata.size(3), qdata.size(1) * qdata.size(2)] + ) kwargs = {} kwargs["device"] = qdata.device kwargs["dtype"] = scale.dtype @@ -73,12 +78,12 @@ def __init__( qdata: torch.Tensor, scale: torch.Tensor, block_size: List[int], - act_quant_args: Optional[QuantizeTensorToFloat8Kwargs] = None, + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, ): self.qdata = qdata self.scale = scale self.block_size = block_size - self.act_quant_args = act_quant_args + self.act_quant_kwargs = act_quant_kwargs def _quantization_type(self): return f"shape={self.shape}, block_size={self.block_size}, device={self.device}, {self.act_quant_kwargs=}" @@ -88,7 +93,7 @@ def from_hp( cls, hp_tensor: torch.Tensor, block_size: List[int], - act_quant_args: Optional[QuantizeTensorToFloat8Kwargs] = None, + act_quant_kwargs: Optional[QuantizeTensorToFloat8Kwargs] = None, ): assert hp_tensor.ndim == 2 and hp_tensor.device.type == "cpu", ( f"Expecting 2D tensor on CPU, but got: {hp_tensor.shape} on {hp_tensor.device.type}" @@ -102,12 +107,29 @@ def from_hp( 128, K, ), f"Unsupported block_size: {block_size} for tensor shape {hp_tensor}" - assert N % 32 == 0, ( - f"Expecting out_features {N} to be multiple of 32, but got {N}" - ) - assert K % block_size[1] == 0, ( - 
f"Expecting in_features {K} to be multiple of group_size {block_size[1]}, but got {K}" + # assert N % 32 == 0, ( + # f"Expecting out_features {N} to be multiple of 32, but got {N}" + # ) + assert act_quant_kwargs is not None, ( + "Activation quantization args must be provided for Float8OpaqueTensor" ) + act_per_group_quant = isinstance(act_quant_kwargs.granularity, PerGroup) + wei_per_group_quant = block_size[1] < K + if act_per_group_quant: + group_size = act_quant_kwargs.granularity.group_size + if wei_per_group_quant: + # weight_tensor is also per group quantized + assert block_size[1] == group_size, ( + "input and weight should have the same group size but got" + f" {block_size[1]} and {group_size}" + ) + if act_per_group_quant or wei_per_group_quant: + assert N % 32 == 0, ( + f"Expecting out_features {N} to be multiple of 32, but got {N}" + ) + assert K % block_size[1] == 0, ( + f"Expecting in_features {K} to be multiple of group_size {block_size[1]}, but got {K}" + ) scale = _choose_scale_float8( hp_tensor, float8_dtype=torch.float8_e4m3fn, @@ -124,7 +146,7 @@ def from_hp( qdata=packed_weight, scale=packed_scale, block_size=block_size, - act_quant_args=act_quant_args, + act_quant_kwargs=act_quant_kwargs, ) @@ -144,7 +166,8 @@ def _(func, types, args, kwargs): assert isinstance(weight_tensor, Float8OpaqueTensor), ( f"Expected weight_tensor to be Float8OpaqueTensor, got: {type(weight_tensor)}" ) - assert input_tensor.shape[-1] == weight_tensor.shape[1], ( + assert weight_tensor.ndim in [2, 4] + assert input_tensor.size(-1) == weight_tensor.size(-1), ( f"Shapes of input and weight do not match, input:{input_tensor.shape}, weight: {weight_tensor.shape}" ) @@ -159,23 +182,11 @@ def _(func, types, args, kwargs): # activation float8 quantization if ( - weight_tensor.act_quant_args is not None - and weight_tensor.act_quant_args.granularity is not None + weight_tensor.act_quant_kwargs is not None + and weight_tensor.act_quant_kwargs.granularity is not None ): - granularity = weight_tensor.act_quant_args.granularity - if isinstance(granularity, PerTensor): - act_scale = _choose_scale_float8( - act_mat, - float8_dtype=torch.float8_e4m3fn, - block_size=list(act_mat.shape), - ) - elif isinstance(granularity, PerToken): - act_scale = _choose_scale_float8( - act_mat, - float8_dtype=torch.float8_e4m3fn, - block_size=[1, act_mat.size(-1)], - ) - elif isinstance(granularity, PerGroup): + granularity = weight_tensor.act_quant_kwargs.granularity + if isinstance(granularity, PerGroup): group_size = granularity.group_size if weight_tensor.block_size[1] < weight_tensor.size(-1): # weight_tensor is also per group quantized @@ -183,15 +194,12 @@ def _(func, types, args, kwargs): "input and weight should have the same group size but got" f" {weight_tensor.block_size[1]} and {group_size}" ) - act_scale = _choose_scale_float8( - act_mat, - float8_dtype=torch.float8_e4m3fn, - block_size=[1, group_size], - ) - else: - raise ValueError( - f"Unsupported activation quantization granularity: {granularity}" - ) + act_block_size = get_block_size(act_mat.shape, granularity) + act_scale = _choose_scale_float8( + act_mat, + float8_dtype=torch.float8_e4m3fn, + block_size=act_block_size, + ) act_mat = _quantize_affine_float8(act_mat, act_scale, torch.float8_e4m3fn) else: raise NotImplementedError( From 01c47ca3e044596e2578f56ec783dfbb14d14916 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 9 Sep 2025 16:00:56 +0000 Subject: [PATCH 08/15] Fix issues in code --- torchao/csrc/cpu/float8_linear.cpp | 8 +++-- 
.../workflows/float8/float8_packing_format.py | 32 +++++++++++++++++++ 2 files changed, 37 insertions(+), 3 deletions(-) create mode 100644 torchao/quantization/quantize_/workflows/float8/float8_packing_format.py diff --git a/torchao/csrc/cpu/float8_linear.cpp b/torchao/csrc/cpu/float8_linear.cpp index 2de50e757a..e60320dd27 100644 --- a/torchao/csrc/cpu/float8_linear.cpp +++ b/torchao/csrc/cpu/float8_linear.cpp @@ -502,6 +502,7 @@ void _float8_linear_impl( const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { + auto y_buf = new (std::align_val_t(8)) float[block_m * block_n]; for (const auto i : c10::irange(begin, end)) { int64_t mc = parallel_on_M ? i / Nc : 0; int64_t nc = parallel_on_M ? i % Nc : i; @@ -509,12 +510,12 @@ void _float8_linear_impl( for (int mci = mc; mci < mc_end; ++mci) { int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; - alignas(64) float y_buf[m_size][block_n] = {0}; + memset(y_buf, 0, sizeof(float) * m_size * block_n); for (int kci = 0; kci < Kc; ++kci) { auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group; auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n; _micro_gemm( - y_buf[0] /*C*/, + y_buf /*C*/, a_ptr + mci * block_m * K + kci * block_k /*A*/, scales_a /*scales_a*/, b_ptr + (nc * Kc + kci) * block_n * block_k /*B*/, @@ -532,7 +533,7 @@ void _float8_linear_impl( wei_quant_mode == PER_ROW ? b_scales_ptr + nc * block_n : nullptr; auto bias_data = bias_ptr ? bias_ptr + nc * block_n : nullptr; store_out( - y_buf[0], + y_buf, c_ptr + mci * block_m * N + nc * block_n, m_size, N /*lda*/, @@ -541,6 +542,7 @@ void _float8_linear_impl( bias_data); } } + delete[] y_buf; if constexpr (cpublas_can_pack) { at::native::cpublas::brgemm_release(); } diff --git a/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py new file mode 100644 index 0000000000..ce8912be7e --- /dev/null +++ b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py @@ -0,0 +1,32 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +from enum import Enum + + +# can switch to StrEnum (https://docs.python.org/3/library/enum.html#enum.StrEnum) +# after python 3.10 is end of life (https://devguide.python.org/versions/) +class Float8PackingFormat(str, Enum): + """Packing format for quantized data in Float8 Tensor subclasses in torchao, represents how + the values in quantized data are packed and laid out in memory. + """ + + """ + plain means the format that quantized Tensor data lays out elements in Tensor sequentially, + for example: for a Tensor of shape (4, 6): + a_0_0, a_0_1, ..., a_0_5, + ... + a_3_0, a_3_1, ..., a_3_5 + + """ + PLAIN = "plain" + + """ + Opaque packing format that's used for tensors that does not have a predefined packing format + (that may be decided on hardware, tensor shape, library availability etc.) and it's not + needed for the rest of the system to understand the specific format that's adopted. 
+ """ + OPAQUE = "opaque" From a42d10f2da6a992c07a905ebc90e35da653b0637 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 9 Sep 2025 17:17:20 +0000 Subject: [PATCH 09/15] Move cpp file --- torchao/csrc/cpu/{ => aten_kernels}/dispatcher.h | 0 torchao/csrc/cpu/{ => aten_kernels}/float8_linear.cpp | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename torchao/csrc/cpu/{ => aten_kernels}/dispatcher.h (100%) rename torchao/csrc/cpu/{ => aten_kernels}/float8_linear.cpp (100%) diff --git a/torchao/csrc/cpu/dispatcher.h b/torchao/csrc/cpu/aten_kernels/dispatcher.h similarity index 100% rename from torchao/csrc/cpu/dispatcher.h rename to torchao/csrc/cpu/aten_kernels/dispatcher.h diff --git a/torchao/csrc/cpu/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp similarity index 100% rename from torchao/csrc/cpu/float8_linear.cpp rename to torchao/csrc/cpu/aten_kernels/float8_linear.cpp From cca41412dcbb5783494db79c77cc5780c6e11710 Mon Sep 17 00:00:00 2001 From: Xia Weiwen Date: Tue, 9 Sep 2025 06:26:31 -0700 Subject: [PATCH 10/15] Update torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../quantize_/workflows/float8/float8_opaque_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py index ae284a3768..afe666c7e2 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py @@ -206,7 +206,7 @@ def _(func, types, args, kwargs): "Activation quantization args not provided for Float8OpaqueTensor" ) - # groupwise int4 quantization + # float8 quantized linear operation y = torch.ops.torchao.float8_linear_cpu.default( act_mat, act_scale, From 4c791125c66daeec4a9e0b2c7c1a5d9616320df5 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 9 Sep 2025 21:46:21 +0000 Subject: [PATCH 11/15] Refine code --- torchao/csrc/cpu/aten_kernels/float8_linear.cpp | 2 +- torchao/quantization/qat/fake_quantize_config.py | 5 +++++ torchao/quantization/quant_api.py | 5 ++++- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp index e60320dd27..11cbbaf640 100644 --- a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp +++ b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp @@ -56,7 +56,7 @@ float8_linear_prepack_impl( new_scales.unsqueeze_(1); } new_scales = new_scales.to(at::kFloat); - int G = scales.size(1); + int G = is_per_tensor ? 1 : new_scales.size(1); TORCH_CHECK(K % G == 0, "K should be divisible by num_groups"); int group_size = K / G; int block_k = group_size > 128 ? 
128 : group_size; diff --git a/torchao/quantization/qat/fake_quantize_config.py b/torchao/quantization/qat/fake_quantize_config.py index dc86aa919f..c5f280803e 100644 --- a/torchao/quantization/qat/fake_quantize_config.py +++ b/torchao/quantization/qat/fake_quantize_config.py @@ -408,6 +408,11 @@ def _infer_fake_quantize_configs( (act_granularity, weight_granularity) = _normalize_granularity( base_config.granularity ) + assert act_granularity == weight_granularity and isinstance( + act_granularity, (PerTensor, PerRow) + ), ( + "Currently only support same granularity for both activations and weights, and only PerTensor or PerRow" + ) act_config = Float8FakeQuantizeConfig( dtype=base_config.activation_dtype, granularity=act_granularity, diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c202cd56e7..e1a5e1bfad 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1753,7 +1753,7 @@ def __post_init__(self): if self.mm_config is None: self.mm_config = Float8MMConfig(use_fast_accum=True) activation_granularity, weight_granularity = _normalize_granularity( - self.granularity, + self.granularity ) if self.packing_format == Float8PackingFormat.PLAIN: assert isinstance(activation_granularity, (PerTensor, PerRow)), ( @@ -1981,6 +1981,9 @@ def _float8_static_activation_float8_weight_transform( weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) + assert activation_granularity == weight_granularity, ( + "Different granularities for activation and weight are not supported" + ) assert isinstance(activation_granularity, PerTensor), ( "Static quantization only supports PerTensor granularity" ) From f28655889af69b65d2f1f4c7dccfbfe27f757a2d Mon Sep 17 00:00:00 2001 From: Xia Weiwen Date: Tue, 9 Sep 2025 07:08:39 -0700 Subject: [PATCH 12/15] Apply suggestions from code review Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../quantize_/workflows/float8/float8_opaque_tensor.py | 3 --- .../quantize_/workflows/float8/float8_packing_format.py | 8 ++++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py index afe666c7e2..39d949f02a 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_opaque_tensor.py @@ -107,9 +107,6 @@ def from_hp( 128, K, ), f"Unsupported block_size: {block_size} for tensor shape {hp_tensor}" - # assert N % 32 == 0, ( - # f"Expecting out_features {N} to be multiple of 32, but got {N}" - # ) assert act_quant_kwargs is not None, ( "Activation quantization args must be provided for Float8OpaqueTensor" ) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py index ce8912be7e..30cf863ac8 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_packing_format.py @@ -16,10 +16,10 @@ class Float8PackingFormat(str, Enum): """ plain means the format that quantized Tensor data lays out elements in Tensor sequentially, - for example: for a Tensor of shape (4, 6): - a_0_0, a_0_1, ..., a_0_5, - ... - a_3_0, a_3_1, ..., a_3_5 + for example, for a Tensor of shape (4, 6): + a_0_0, a_0_1, ..., a_0_5, + ... 
+ a_3_0, a_3_1, ..., a_3_5 """ PLAIN = "plain" From 704e1566b9191d170e2d66171e2e6f06df14a5ed Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Tue, 16 Sep 2025 13:37:14 +0000 Subject: [PATCH 13/15] Refine kernel code --- .../float8/test_float8_opaque_tensor.py | 4 +- torchao/csrc/cpu/aten_kernels/dispatcher.h | 191 -------- .../csrc/cpu/aten_kernels/float8_linear.cpp | 446 ++++++++---------- torchao/csrc/cpu/aten_kernels/utils.h | 111 +++++ 4 files changed, 319 insertions(+), 433 deletions(-) delete mode 100644 torchao/csrc/cpu/aten_kernels/dispatcher.h create mode 100644 torchao/csrc/cpu/aten_kernels/utils.h diff --git a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py index e7f83d3533..3c6a64c53e 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py @@ -134,11 +134,11 @@ def test_dynamic_float8_linear_per_tensor_cpu( @common_utils.parametrize("dtype", [torch.float, torch.bfloat16, torch.half]) @common_utils.parametrize("x_dim", [2, 3]) @common_utils.parametrize("bias", [True, False]) - def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias): + @common_utils.parametrize("bs", [4, 128]) + def test_dynamic_float8_linear_ref_cpu(self, dtype, x_dim, bias, bs): device = "cpu" # the shape is not supported by cpp kernel, so the ref path will be used. m = ToyLinearModel(120, 120, bias=bias).eval().to(dtype).to(device) - bs = 4 example_inputs = m.example_inputs(batch_size=bs, dtype=dtype, device=device) if x_dim == 3: example_inputs = (example_inputs[0].unsqueeze(0),) diff --git a/torchao/csrc/cpu/aten_kernels/dispatcher.h b/torchao/csrc/cpu/aten_kernels/dispatcher.h deleted file mode 100644 index 81edbfa971..0000000000 --- a/torchao/csrc/cpu/aten_kernels/dispatcher.h +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright (c) Meta Platforms, Inc. and affiliates. -// All rights reserved. -// -// This source code is licensed under the BSD 3-Clause license found in the -// LICENSE file in the root directory of this source tree. - -#pragma once - -#include -#include - -template < - typename IntegralType, - int n, - IntegralType First, - IntegralType... Rest> -struct enumerate_dispatcher_helper { - template - inline static void call( - IntegralType i, - const std::function& comparator, - const Lambda1& function, - const Lambda2& fallback, - Args... args) { - if (comparator(i, First)) - function( - std::integral_constant{}, - std::forward(args)...); - else - enumerate_dispatcher_helper::call( - i, comparator, function, fallback, std::forward(args)...); - } -}; - -template -struct enumerate_dispatcher_helper { - template - inline static void call( - IntegralType i, - const std::function& comparator, - const Lambda1& function, - const Lambda2& fallback, - Args... args) { - if (comparator(i, First)) - function( - std::integral_constant{}, - std::forward(args)...); - else - fallback(i, std::forward(args)...); - } -}; - -// dispatch a list of integers to a lambda function -template -struct enumerate_dispatcher { - template - inline static void call( - IntegralType i, - const Lambda1& function, - const Lambda2& fallback, - Args... 
args) { - enumerate_dispatcher_helper:: - call( - i, - [&](IntegralType a, IntegralType b) { return a == b; }, - function, - fallback, - std::forward(args)...); - } -}; - -// A helper function that returns the last N-1 items of a tuple as a new tuple -template -auto get_last_n_minus_one_impl(TupleType&& t, std::index_sequence) { - return std::tuple_cat(std::make_tuple(std::get(t))...); -} - -// A function that returns the last N-1 items of a tuple as a new tuple -template -auto get_last_n_minus_one(TupleType&& t) { - // Get the size of the tuple - constexpr auto size = - std::tuple_size::type>::value; - // Check if the size is greater than one - return get_last_n_minus_one_impl( - std::forward(t), std::make_index_sequence{}); -} - -template < - typename TupleType, - std::enable_if_t::value == 1, bool> = true> -auto get_last_n_minus_one(TupleType&& t) { - return std::make_tuple(); -} - -template < - typename IntegralTypesProcessed, - typename IntegralTypesToProcess, - typename Dispatchers> -struct product_dispatcher_helper; - -template -struct product_dispatcher_helper< - std::tuple, - std::tuple<>, - std::tuple<>> { - template - inline static void call( - std::tuple<>, - std::tuple constants, - std::tuple<>, - const Lambda1& function, - const Lambda2& fallback, - Args... args) { - function(constants, std::forward(args)...); - } -}; - -template < - typename... IntegralTypeProcessed, - typename... IntegeralTypeToProcess, - typename... Dispatcher> -struct product_dispatcher_helper< - std::tuple, - std::tuple, - std::tuple> { - template - inline static void call( - std::tuple dispatchers, - std::tuple constants, - std::tuple integrals, - const Lambda1& function, - const Lambda2& fallback, - Args... args) { - std::get<0>(dispatchers) - .call( - std::get<0>(integrals), - [&](auto i, Args... args) { - auto new_dispatchers = get_last_n_minus_one(dispatchers); - auto new_constants = - std::tuple_cat(constants, std::tuple(i)); - auto new_integrals = get_last_n_minus_one(integrals); - product_dispatcher_helper< - decltype(new_constants), - decltype(new_integrals), - decltype(new_dispatchers)>:: - call( - new_dispatchers, - new_constants, - new_integrals, - function, - fallback, - std::forward(args)...); - }, - [&](auto i, Args... args) { - fallback( - std::tuple_cat(constants, integrals), - std::forward(args)...); - }, - std::forward(args)...); - } -}; - -template -struct product_dispatcher; - -// dispatch to a carsian product of a list of integers to a lambda function -template -struct product_dispatcher< - std::tuple, - std::tuple> { - template - inline static void call( - std::tuple integrals, - const Lambda1& function, - const Lambda2& fallback, - Args... 
args) { - static auto dispatchers = std::tuple{}; - product_dispatcher_helper< - std::tuple<>, - std::tuple, - std::tuple>:: - call( - dispatchers, - std::tuple<>{}, - integrals, - function, - fallback, - std::forward(args)...); - } -}; diff --git a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp index 11cbbaf640..d15808bfac 100644 --- a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp +++ b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp @@ -2,7 +2,12 @@ #include #include #include -#include "dispatcher.h" +#include "utils.h" +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif namespace torchao { @@ -78,9 +83,9 @@ float8_linear_prepack_impl( #if defined(CPU_CAPABILITY_AVX512) if (cpublas_could_pack()) { #ifdef CPUBLAS_BRGEMM_F8F8F32 - constexpr int vnni_size = 4; // for fp8 + constexpr int vnni_size = get_vnni_size(); // for fp8 #else - constexpr int vnni_size = 2; // for float16 + constexpr int vnni_size = get_vnni_size(); // for bfloat16 #endif blocked_weight = at::empty({Nc, Kc, block_k, block_n}, weight.options()); auto weight_ptr = reinterpret_cast(weight_reordered.data_ptr()); @@ -119,94 +124,59 @@ float8_linear_prepack_impl( } #if defined(CPU_CAPABILITY_AVX512) -alignas(64) static uint16_t e4m3_to_16bit[256]; - -template -static void initialize_e4m3_to_16bit_tables() { - // run only once - static bool initialized_16bit = false; - if (!initialized_16bit) { - for (uint8_t u8 = 0; u8 < 256; ++u8) { - auto value = static_cast(c10::bit_cast(u8)); - uint16_t value_bits = c10::bit_cast(value); - e4m3_to_16bit[u8] = value_bits; - if (u8 == 255) { - break; - } - } - initialized_16bit = true; - } -} +// this doesn't handle NaN. +inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) { + const __m512i x = _mm512_cvtepu8_epi16(fp8_vec); -template -static void cvt_e4m3_16bit_intrinsic_lut( - const at::Float8_e4m3fn* __restrict__ in, - T* out, - int64_t len) { - for (size_t i = 0; i < len; i += 64) { - __m512i fp8_vec = _mm512_loadu_si512((__m512i*)&in[i]); - __m128i group0 = _mm512_castsi512_si128(fp8_vec); - __m128i group1 = _mm512_extracti32x4_epi32(fp8_vec, 1); - __m128i group2 = _mm512_extracti32x4_epi32(fp8_vec, 2); - __m128i group3 = _mm512_extracti32x4_epi32(fp8_vec, 3); - - __m512i indices0 = _mm512_cvtepu8_epi32(group0); - __m512i indices1 = _mm512_cvtepu8_epi32(group1); - __m512i indices2 = _mm512_cvtepu8_epi32(group2); - __m512i indices3 = _mm512_cvtepu8_epi32(group3); - - // Gather BF16 conversion results from the lookup table. - __m512i bf16_i32_vec0 = _mm512_i32gather_epi32(indices0, e4m3_to_16bit, 2); - __m512i bf16_i32_vec1 = _mm512_i32gather_epi32(indices1, e4m3_to_16bit, 2); - __m512i bf16_i32_vec2 = _mm512_i32gather_epi32(indices2, e4m3_to_16bit, 2); - __m512i bf16_i32_vec3 = _mm512_i32gather_epi32(indices3, e4m3_to_16bit, 2); - - // Helper lambda: Convert 16 32-bit ints (in a __m512i) to 16 16-bit ints. 
- auto convert_32_to_16 = [](__m512i vec) -> __m256i { - return _mm512_cvtepi32_epi16(vec); - }; - - __m256i bf16_i16_vec0 = convert_32_to_16(bf16_i32_vec0); - __m256i bf16_i16_vec1 = convert_32_to_16(bf16_i32_vec1); - __m256i bf16_i16_vec2 = convert_32_to_16(bf16_i32_vec2); - __m256i bf16_i16_vec3 = convert_32_to_16(bf16_i32_vec3); - - _mm256_storeu_si256((__m256i*)(out + i + 0), bf16_i16_vec0); - _mm256_storeu_si256((__m256i*)(out + i + 16), bf16_i16_vec1); - _mm256_storeu_si256((__m256i*)(out + i + 32), bf16_i16_vec2); - _mm256_storeu_si256((__m256i*)(out + i + 48), bf16_i16_vec3); - } -} + const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4); + const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3); + const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7); + const __m512i nonsign = _mm512_or_si512(exp, mant); -static void _convert_B_to_bf16( - const at::Float8_e4m3fn* __restrict__ B, - at::BFloat16* dqB, - int64_t len) { - initialize_e4m3_to_16bit_tables(); - int tail = len % 64; - cvt_e4m3_16bit_intrinsic_lut(B, dqB, len - tail); - for (int i = len - tail; i < len; ++i) { - dqB[i] = (at::BFloat16)B[i]; - } + const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8); + const __m512i combined = _mm512_or_si512(nonsign, sign); + + const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512()); + return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined); } -static void _convert_A_to_bf16( - const at::Float8_e4m3fn* __restrict__ A, - at::BFloat16* dqA, - int64_t M, - int64_t K, - int64_t lda) { - initialize_e4m3_to_16bit_tables(); - for (int m = 0; m < M; ++m) { - int tail = K % 64; - int body = K - tail; - cvt_e4m3_16bit_intrinsic_lut(A + m * lda, dqA + m * K, body); - for (int k = body; k < K; ++k) { - dqA[m * K + k] = (at::BFloat16)A[m * lda + k]; +static void cvt_f8e4m3_to_bf16( + const at::Float8_e4m3fn* __restrict__ in, + at::BFloat16* out, + int64_t rows, + int64_t cols, + int64_t stride) { + if (stride == cols) { + // A contiguous buffer + size_t len = rows * cols; + size_t i = 0; + for (; i < len; i += 32) { + __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[i]); + __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); + _mm512_storeu_si512((__m512i*)(out + i), (__m512i)bf16_vec); + } + for (; i < len; ++i) { + out[i] = (at::BFloat16)in[i]; + } + } else { + // Non-contiguous. 
Access each row with stride + TORCH_CHECK(stride > cols); + for (int r = 0; r < rows; ++r) { + size_t i = 0; + size_t vec_len = cols / 32 * 32; + for (; i < vec_len; i += 32) { + __m256i fp8_vec = _mm256_loadu_si256((__m256i*)&in[r * stride + i]); + __m512bh bf16_vec = cvt_e4m3_bf16_intrinsic_no_nan(fp8_vec); + _mm512_storeu_si512((__m512i*)(out + r * cols + i), (__m512i)bf16_vec); + } + for (; i < cols; ++i) { + out[r * cols + i] = (at::BFloat16)in[r * stride + i]; + } } } } + // accumulate and store result to buffer // if act/wei are per_group quantized, apply scales template @@ -227,9 +197,9 @@ static void _store_result( a_scale = *(scale_a + m * ldsa); va_scale = _mm512_set1_ps(a_scale); } - int n = 0; -#pragma GCC unroll 2 - for (; n < N; n += 16) { + constexpr int N_UNROLL = N / 16; + c10::ForcedUnroll{}([&](auto i) { + constexpr int n = i * 16; __m512 vc_f = _mm512_loadu_ps(input + m * ldi + n); if constexpr (act_quant_mode == PER_GROUP) { vc_f = _mm512_mul_ps(vc_f, va_scale); @@ -244,8 +214,9 @@ static void _store_result( } else { _mm512_storeu_ps(output + m * ldo + n, vc_f); } - } - for (; n < N; ++n) { + }); + constexpr int tail_start = N / 16 * 16; + for (int n = tail_start; n < N; ++n) { float dq_val = input[m * ldi + n]; if constexpr (act_quant_mode == PER_GROUP) { dq_val = dq_val * a_scale; @@ -263,29 +234,119 @@ static void _store_result( } } -#else -static void _convert_B_to_bf16( - const at::Float8_e4m3fn* B, - at::BFloat16* dqB, - int64_t len) { - for (int i = 0; i < len; ++i) { - dqB[i] = (at::BFloat16)B[i]; +// Store result to output buffer with dtype conversion +// If act/wei are per_row or per_tensor quantized, apply scales +// If bias is not null, add bias +template +inline void store_out( + const float* y_buf, + out_dtype* c_ptr, + int64_t M, + int64_t lda, + const float* scales_a, + const float* scales_b, + const float* bias) { + float a_scale = 1.0, b_scale = 1.0; + __m512 va_scale, vb_scale; + if constexpr (act_quant_mode == PER_TENSOR) { + a_scale = *scales_a; + } + if constexpr (wei_quant_mode == PER_TENSOR) { + b_scale = *scales_b; + vb_scale = _mm512_set1_ps(b_scale); } + for (int i = 0; i < M; ++i) { + if constexpr (act_quant_mode == PER_ROW) { + a_scale = *(scales_a + i); + } + if constexpr (act_quant_mode != PER_GROUP) { + va_scale = _mm512_set1_ps(a_scale); + } + constexpr int N_UNROLL = N / 16; + c10::ForcedUnroll{}([&](auto idx) { + constexpr int j = idx * 16; + __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); + __m512 bias_vec = bias ? 
_mm512_loadu_ps(bias + j) : _mm512_setzero_ps(); + if constexpr (act_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, va_scale); + } + if constexpr (wei_quant_mode == PER_ROW) { + vb_scale = _mm512_loadu_ps(scales_b + j); + } + if constexpr (wei_quant_mode != PER_GROUP) { + y_vec = _mm512_mul_ps(y_vec, vb_scale); + } + y_vec = _mm512_add_ps(y_vec, bias_vec); + if constexpr (std::is_same::value) { + _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); + } else if constexpr (std::is_same::value) { + __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); + } else if constexpr (std::is_same::value) { + __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); + _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); + } else { + TORCH_CHECK(false, "Unsupported output dtype"); + } + }); + constexpr int tail_start = N / 16 * 16; + for (int j = tail_start; j < N; ++j) { + if constexpr (wei_quant_mode == PER_ROW) { + b_scale = scales_b[j]; + } + c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); + } + } // for M } -static void _convert_A_to_bf16( - const at::Float8_e4m3fn* __restrict__ A, - at::BFloat16* dqA, - int64_t M, - int64_t K, - int64_t lda) { - for (int m = 0; m < M; ++m) { - for (int k = 0; k < K; ++k) { - dqA[m * K + k] = (at::BFloat16)A[m * lda + k]; +#else // no AVX512 + +static void cvt_f8e4m3_to_bf16( + const at::Float8_e4m3fn* __restrict__ in, + at::BFloat16* out, + int64_t rows, + int64_t cols, + int64_t stride) { + for (int r = 0; r < rows; ++r) { + for (int c = 0; c < cols; ++c) { + out[r * cols + c] = (at::BFloat16)in[r * stride + c]; } } } -#endif + +// Store result to output buffer with dtype conversion +// If act/wei are per_row or per_tensor quantized, apply scales +// If bias is not null, add bias +template +inline void store_out( + const float* y_buf, + out_dtype* c_ptr, + int64_t M, + int64_t lda, + const float* scales_a, + const float* scales_b, + const float* bias) { + float a_scale = 1.0, b_scale = 1.0; + if constexpr (act_quant_mode == PER_TENSOR) { + a_scale = *scales_a; + } + if constexpr (wei_quant_mode == PER_TENSOR) { + b_scale = *scales_b; + } + for (int i = 0; i < M; ++i) { + if constexpr (act_quant_mode == PER_ROW) { + a_scale = *(scales_a + i); + } + for (int j = 0; j < N; ++j) { + if constexpr (wei_quant_mode == PER_ROW) { + b_scale = scales_b[j]; + } + c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); + } + } // for M +} + +#endif // CPU_CAPABILITY_AVX512 template void _micro_gemm( @@ -305,9 +366,9 @@ void _micro_gemm( // Finally accumulate and store results #ifndef CPUBLAS_BRGEMM_F8F8F32 at::BFloat16 dqB[K * N]; - _convert_B_to_bf16(B, dqB, K * N); + cvt_f8e4m3_to_bf16(B, dqB, K, N, N); at::BFloat16 dqA[M * K]; - _convert_A_to_bf16(A, dqA, M, K, lda); + cvt_f8e4m3_to_bf16(A, dqA, M, K, lda); #endif #if defined(CPU_CAPABILITY_AVX512) if constexpr (cpublas_can_pack) { @@ -339,8 +400,8 @@ void _micro_gemm( C_f32, true /* is_vnni */); #endif - _mm_prefetch(B + N * K, _MM_HINT_T0); - _mm_prefetch(A + K, _MM_HINT_T0); + _mm_prefetch(B + N * (K + 128), _MM_HINT_T0); + _mm_prefetch(A + K + 128, _MM_HINT_T0); _store_result( C, C_f32, @@ -375,77 +436,6 @@ void _micro_gemm( } } -// Store result to output buffer with dtype conversion -// If act/wei are per_row or per_tensor quantized, apply scales -// If bias is not null, add bias -template -inline void store_out( - const float* y_buf, - out_dtype* c_ptr, - int64_t M, - 
int64_t lda, - const float* scales_a, - const float* scales_b, - const float* bias) { - float a_scale = 1.0, b_scale = 1.0; -#if defined(CPU_CAPABILITY_AVX512) - __m512 va_scale, vb_scale; -#endif - if constexpr (act_quant_mode == PER_TENSOR) { - a_scale = *scales_a; - } - if constexpr (wei_quant_mode == PER_TENSOR) { - b_scale = *scales_b; -#if defined(CPU_CAPABILITY_AVX512) - vb_scale = _mm512_set1_ps(b_scale); -#endif - } - for (int i = 0; i < M; ++i) { - if constexpr (act_quant_mode == PER_ROW) { - a_scale = *(scales_a + i); - } - int j = 0; -#if defined(CPU_CAPABILITY_AVX512) - if constexpr (act_quant_mode != PER_GROUP) { - va_scale = _mm512_set1_ps(a_scale); - } -#pragma GCC unroll 2 - for (; j < N; j += 16) { - __m512 y_vec = _mm512_loadu_ps(y_buf + i * N + j); - __m512 bias_vec = bias ? _mm512_loadu_ps(bias + j) : _mm512_setzero_ps(); - if constexpr (act_quant_mode != PER_GROUP) { - y_vec = _mm512_mul_ps(y_vec, va_scale); - } - if constexpr (wei_quant_mode == PER_ROW) { - vb_scale = _mm512_loadu_ps(scales_b + j); - } - if constexpr (wei_quant_mode != PER_GROUP) { - y_vec = _mm512_mul_ps(y_vec, vb_scale); - } - y_vec = _mm512_add_ps(y_vec, bias_vec); - if constexpr (std::is_same::value) { - _mm512_storeu_ps(c_ptr + i * lda + j, y_vec); - } else if constexpr (std::is_same::value) { - __m256i y_bf16_vec = at::vec::cvtfp32_bf16(y_vec); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_bf16_vec); - } else if constexpr (std::is_same::value) { - __m256i y_fp16_vec = at::vec::cvtfp32_fp16(y_vec); - _mm256_storeu_si256(reinterpret_cast<__m256i*>(c_ptr + i * lda + j), y_fp16_vec); - } else { - TORCH_CHECK(false, "Unsupported output dtype"); - } - } -#else - for (; j < N; ++j) { - if constexpr (wei_quant_mode == PER_ROW) { - b_scale = scales_b[j]; - } - c_ptr[i * lda + j] = static_cast(y_buf[i * N + j] * a_scale * b_scale); - } -#endif - } // for M -} - template void _float8_linear_impl( const at::Tensor& input, @@ -469,20 +459,8 @@ void _float8_linear_impl( TORCH_CHECK(weight.size(3) == block_n, "Float8 linear: unexpected weight shape"); int64_t N = Nc * block_n; TORCH_CHECK(K == Kc * block_k, "Float8 linear: weight and input shapes mismatch"); - int64_t block_m = [&]() -> long { - if (M <= 48) { - return M; - } else if (M < 64) { - return 32; - } else if (M < 96) { - return 64; - } else { - return 128; - } - }(); - int64_t Mc = (M + block_m - 1) / block_m; - bool parallel_on_M = M > 128; - int64_t num_blocks = parallel_on_M ? Mc * Nc : Nc; + auto [parallel_on_M, block_m, Mc, Mc_parallel] = get_m_blocking(M); + int64_t num_parallel_blocks = Mc_parallel * Nc; // scales shape = [Nc, G, block_n] int64_t num_groups = wei_quant_mode == PER_TENSOR ? 1 : weight_scales.size(1); @@ -501,30 +479,35 @@ void _float8_linear_impl( out_dtype* c_ptr = output.data_ptr(); const float* bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; - at::parallel_for(0, num_blocks, 1, [&](int64_t begin, int64_t end) { - auto y_buf = new (std::align_val_t(8)) float[block_m * block_n]; + int64_t block_size = block_m * block_n; + int64_t num_thread = at::get_num_threads(); + at::Tensor y_buffer = at::empty({num_thread, block_size}, output.options().dtype(at::kFloat)); + + at::parallel_for(0, num_parallel_blocks, 1, [&](int64_t begin, int64_t end) { + float* y_buf = y_buffer.data_ptr() + at::get_thread_num() * block_size; + int64_t mc = 0, nc = 0; + at::native::data_index_init(begin, mc, Mc_parallel, nc, Nc); for (const auto i : c10::irange(begin, end)) { - int64_t mc = parallel_on_M ? 
i / Nc : 0; - int64_t nc = parallel_on_M ? i % Nc : i; + (void)i; // Suppress unused variable int64_t mc_end = parallel_on_M ? mc + 1 : Mc; for (int mci = mc; mci < mc_end; ++mci) { int64_t m_size = mci * block_m + block_m > M ? M - mci * block_m : block_m; - memset(y_buf, 0, sizeof(float) * m_size * block_n); + zero_buffer(y_buf, m_size * block_n); for (int kci = 0; kci < Kc; ++kci) { auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group; auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n; _micro_gemm( - y_buf /*C*/, - a_ptr + mci * block_m * K + kci * block_k /*A*/, - scales_a /*scales_a*/, - b_ptr + (nc * Kc + kci) * block_n * block_k /*B*/, - scales_b /*scales_b*/, - m_size /*M*/, - block_k /*K*/, - K /*lda*/, - block_n /*ldc*/, - ldsa /*ldsa*/); + /* C */ y_buf, + /* A */ a_ptr + mci * block_m * K + kci * block_k, + /* A scales */ scales_a, + /* B */ b_ptr + (nc * Kc + kci) * block_n * block_k, + /* B scales */ scales_b, + /* M */ m_size, + /* K */ block_k, + /* lda */ K, + /* ldc */ block_n, + /* ldsa */ ldsa); } // store y_buf to output with dtype conversion auto scales_a = act_quant_mode == PER_TENSOR ? a_scales_ptr : @@ -541,8 +524,8 @@ void _float8_linear_impl( scales_b, bias_data); } + at::native::data_index_step(mc, Mc_parallel, nc, Nc); } - delete[] y_buf; if constexpr (cpublas_can_pack) { at::native::cpublas::brgemm_release(); } @@ -567,7 +550,10 @@ at::Tensor float8_linear_impl( if (weight.dim() == 2) { TORCH_CHECK(act_quant_mode != PER_GROUP && wei_quant_mode != PER_GROUP, "FP8 linear: Per-group quantization is not supported in the fallback path"); - auto y_fp32 = at::linear(input.to(at::kFloat).mul(input_scales), weight.to(at::kFloat).mul(weight_scales), bias); + auto y_fp32 = at::linear( + input.to(at::kFloat).mul_(input_scales), + weight.to(at::kFloat).mul_(weight_scales), + bias); return y_fp32.to(output_dtype); } @@ -576,35 +562,15 @@ at::Tensor float8_linear_impl( out_sizes.back() = N; auto output = at::empty(out_sizes, input.options().dtype(output_dtype)); - product_dispatcher< - std::tuple< - /*output_dtype*/ at::ScalarType, - /*cpublas_can_pack*/ bool, - /*act_quant_mode*/ int, - /*wei_quant_mode*/ int>, - std::tuple< - enumerate_dispatcher, - enumerate_dispatcher, - enumerate_dispatcher, - enumerate_dispatcher>>:: - call( - std::make_tuple(output_dtype, cpublas_can_pack, act_quant_mode, wei_quant_mode), - [&](auto tuple) { - constexpr auto o_dtype = std::get<0>(tuple); - using out_dtype = typename c10::impl::ScalarTypeToCPPType::type; - constexpr bool cpublas_can_pack_v = std::get<1>(tuple); - constexpr int act_quant_mode_v = std::get<2>(tuple); - constexpr int wei_quant_mode_v = std::get<3>(tuple); - _float8_linear_impl( - input, - input_scales, - weight, - weight_scales, - bias, - output); - }, - [](auto tuple) { TORCH_CHECK(false, "Not implemented for this configuration"); }); - + AT_DISPATCH_LINEAR_KERNEL(output_dtype, cpublas_can_pack, act_quant_mode, wei_quant_mode, [&](){ + _float8_linear_impl( + input, + input_scales, + weight, + weight_scales, + bias, + output); + }); return output; } diff --git a/torchao/csrc/cpu/aten_kernels/utils.h b/torchao/csrc/cpu/aten_kernels/utils.h new file mode 100644 index 0000000000..0a238a316e --- /dev/null +++ b/torchao/csrc/cpu/aten_kernels/utils.h @@ -0,0 +1,111 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. 
+// +// This source code is licensed under the BSD 3-Clause license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +int64_t get_m_block(int64_t M) { + if (M <= 48) { + return M; + } else if (M < 64) { + return 32; + } else if (M < 96) { + return 64; + } else { + return 128; + } +} + +std::tuple +get_m_blocking(int64_t M) { + bool parallel_on_M = M > 128; + int64_t block_m = get_m_block(M); + int64_t Mc = (M + block_m - 1) / block_m; + int64_t Mc_parallel = parallel_on_M ? Mc : 1; + return std::make_tuple(parallel_on_M, block_m, Mc, Mc_parallel); +} + +#if defined(CPU_CAPABILITY_AVX512) +template +void zero_buffer(T* data, int64_t size) { + const int32_t vec_size = at::vec::Vectorized::size(); + auto zero_vec = at::vec::Vectorized(0); + int64_t d = 0; + for (; d < size - (size % vec_size); d += vec_size) { + zero_vec.store(data + d); + } + if (d < size) { + zero_vec.store(data + d, size - d); + } +} +#else +template +void zero_buffer(T* data, int64_t size) { + memset(data, 0, sizeof(T) * size); +} +#endif + +template struct vnni_traits; +template <> struct vnni_traits { static constexpr int size = 2; }; +template <> struct vnni_traits { static constexpr int size = 2; }; +template <> struct vnni_traits { static constexpr int size = 4; }; +template <> struct vnni_traits { static constexpr int size = 4; }; + +template constexpr int get_vnni_size() { return vnni_traits::size; } + + +// Utilities for dispatch +#define AT_DISPATCH_OUT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Float, out_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::BFloat16, out_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Half, out_t, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ENUM(VALUE, TYPE, HINT, ...) \ + case VALUE: { \ + constexpr TYPE HINT = VALUE; \ + __VA_ARGS__; \ + break; \ + } + +#define AT_DISPATCH_BOOL(VALUE, NAME, HINT, ...) \ + [&]() { \ + switch (VALUE) { \ + AT_DISPATCH_CASE_ENUM(true, bool, HINT, __VA_ARGS__) \ + AT_DISPATCH_CASE_ENUM(false, bool, HINT, __VA_ARGS__) \ + } \ + }() + +#define AT_DISPATCH_QUANT_MODE(MODE, NAME, HINT, ...) \ + [&]() { \ + switch (MODE) { \ + AT_DISPATCH_CASE_ENUM(PER_TENSOR, int, HINT, __VA_ARGS__) \ + AT_DISPATCH_CASE_ENUM(PER_ROW, int, HINT, __VA_ARGS__) \ + AT_DISPATCH_CASE_ENUM(PER_GROUP, int, HINT, __VA_ARGS__) \ + } \ + }() + +#define AT_DISPATCH_LINEAR_KERNEL(OUT_DTYPE, CAN_PACK, A_QUANT_MODE, B_QUANT_MODE, ...) 
\ + AT_DISPATCH_BOOL( \ + CAN_PACK, "cpublas_can_pack", can_pack, \ + AT_DISPATCH_QUANT_MODE( \ + A_QUANT_MODE, "act_quant_mode", a_quant_mode, \ + AT_DISPATCH_QUANT_MODE( \ + B_QUANT_MODE, "wei_quant_mode", b_quant_mode, \ + AT_DISPATCH_OUT_TYPES( \ + OUT_DTYPE, "out_dtype", __VA_ARGS__ \ + ) \ + ) \ + ) \ + ) From 8beeb03df4cc0b770298a4214400622d368e4f64 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Wed, 17 Sep 2025 12:26:24 +0000 Subject: [PATCH 14/15] Allocate buffer outside micro gemm kernel --- .../csrc/cpu/aten_kernels/float8_linear.cpp | 77 ++++++++++++------- 1 file changed, 48 insertions(+), 29 deletions(-) diff --git a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp index d15808bfac..0301568477 100644 --- a/torchao/csrc/cpu/aten_kernels/float8_linear.cpp +++ b/torchao/csrc/cpu/aten_kernels/float8_linear.cpp @@ -180,7 +180,7 @@ static void cvt_f8e4m3_to_bf16( // accumulate and store result to buffer // if act/wei are per_group quantized, apply scales template -static void _store_result( +static void _accumulate_result( float* __restrict__ output, const float* __restrict__ input, const float* __restrict__ scale_a, @@ -359,20 +359,16 @@ void _micro_gemm( int64_t K, int64_t lda, int64_t ldc, - int64_t ldsa) { + int64_t ldsa, + float* ukernel_buf, + at::BFloat16* dqA_buf, + at::BFloat16* dqB_buf) { // If FP8 brgemm is not available, convert A/B to bf16 for computation // Compute GEMM fp8 * fp8 -> fp32 (or bf16 * bf16 -> fp32) // If per_group quant, apply scales. Otherwise, don't apply scales here // Finally accumulate and store results -#ifndef CPUBLAS_BRGEMM_F8F8F32 - at::BFloat16 dqB[K * N]; - cvt_f8e4m3_to_bf16(B, dqB, K, N, N); - at::BFloat16 dqA[M * K]; - cvt_f8e4m3_to_bf16(A, dqA, M, K, lda); -#endif #if defined(CPU_CAPABILITY_AVX512) if constexpr (cpublas_can_pack) { - float C_f32[M * N]; #ifdef CPUBLAS_BRGEMM_F8F8F32 at::native::cpublas::brgemm( M, @@ -384,9 +380,11 @@ void _micro_gemm( false /* add_C */, A, B, - C_f32, + ukernel_buf, true /* is_vnni */); #else + cvt_f8e4m3_to_bf16(A, dqA_buf, M, K, lda); + cvt_f8e4m3_to_bf16(B, dqB_buf, K, N, N); at::native::cpublas::brgemm( M, N, @@ -395,16 +393,16 @@ void _micro_gemm( N /*ldb*/, N /*ldc*/, false /* add_C */, - dqA, - dqB, - C_f32, + dqA_buf, + dqB_buf, + ukernel_buf, true /* is_vnni */); #endif _mm_prefetch(B + N * (K + 128), _MM_HINT_T0); _mm_prefetch(A + K + 128, _MM_HINT_T0); - _store_result( + _accumulate_result( C, - C_f32, + ukernel_buf, scales_a, scales_b, M, @@ -418,11 +416,7 @@ void _micro_gemm( for (int64_t j = 0; j < N; ++j) { float sum = 0; for (int64_t k = 0; k < K; ++k) { -#ifdef CPUBLAS_BRGEMM_F8F8F32 sum += ((float)A[i * lda + k] * (float)B[k * N + j]); -#else - sum += ((float)dqA[i * K + k] * dqB[k * N + j]); -#endif } if constexpr (act_quant_mode == PER_GROUP) { sum *= scales_a[i * ldsa]; @@ -482,9 +476,31 @@ void _float8_linear_impl( int64_t block_size = block_m * block_n; int64_t num_thread = at::get_num_threads(); at::Tensor y_buffer = at::empty({num_thread, block_size}, output.options().dtype(at::kFloat)); + // Create buffer for brgemm output and dqA/dqB (optional) +#if defined(CPU_CAPABILITY_AVX512) + // buffer for brgemm output in float32 + int64_t buffer_size = block_size * 2; // float32 = bfloat16 * 2 +#ifndef CPUBLAS_BRGEMM_F8F8F32 + // buffers for dqA & dqB in bf16 + buffer_size += (block_k * block_n + block_m * block_k); +#endif + at::Tensor micro_gemm_buffer = at::empty({num_thread, buffer_size}, output.options().dtype(at::kBFloat16)); 
+#endif at::parallel_for(0, num_parallel_blocks, 1, [&](int64_t begin, int64_t end) { + // Get the address of pre-allocated buffers float* y_buf = y_buffer.data_ptr() + at::get_thread_num() * block_size; + at::BFloat16 *dqA_buffer = nullptr, *dqB_buffer = nullptr; + float* ukernel_buf = nullptr; +#if defined(CPU_CAPABILITY_AVX512) + at::BFloat16* micro_gemm_buf = micro_gemm_buffer.data_ptr() + at::get_thread_num() * buffer_size; + ukernel_buf = reinterpret_cast(micro_gemm_buf); +#ifndef CPUBLAS_BRGEMM_F8F8F32 + dqA_buffer = micro_gemm_buf; + dqB_buffer = micro_gemm_buf + block_m * block_k; + ukernel_buf = reinterpret_cast(micro_gemm_buf + block_m * block_k + block_k * block_n); +#endif +#endif int64_t mc = 0, nc = 0; at::native::data_index_init(begin, mc, Mc_parallel, nc, Nc); for (const auto i : c10::irange(begin, end)) { @@ -498,16 +514,19 @@ void _float8_linear_impl( auto scales_a = a_scales_ptr + mci * block_m * num_groups + kci / block_per_group; auto scales_b = b_scales_ptr + nc * block_n * num_groups + kci / block_per_group * block_n; _micro_gemm( - /* C */ y_buf, - /* A */ a_ptr + mci * block_m * K + kci * block_k, - /* A scales */ scales_a, - /* B */ b_ptr + (nc * Kc + kci) * block_n * block_k, - /* B scales */ scales_b, - /* M */ m_size, - /* K */ block_k, - /* lda */ K, - /* ldc */ block_n, - /* ldsa */ ldsa); + /* C */ y_buf, + /* A */ a_ptr + mci * block_m * K + kci * block_k, + /* scales_a */ scales_a, + /* B */ b_ptr + (nc * Kc + kci) * block_n * block_k, + /* scales_b */ scales_b, + /* M */ m_size, + /* K */ block_k, + /* lda */ K, + /* ldc */ block_n, + /* ldsa */ ldsa, + /* ukernel_buf */ ukernel_buf, + /* dqA_buf */ dqA_buffer, + /* dqB_buf */ dqB_buffer); } // store y_buf to output with dtype conversion auto scales_a = act_quant_mode == PER_TENSOR ? 
a_scales_ptr : From 5e75764e5485b80f50a4fccb33dfc65cc921a973 Mon Sep 17 00:00:00 2001 From: "Xia, Weiwen" Date: Fri, 19 Sep 2025 13:43:49 +0000 Subject: [PATCH 15/15] packing_format --> float8_packing_format --- .../workflows/float8/test_float8_opaque_tensor.py | 2 +- torchao/quantization/quant_api.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py index 3c6a64c53e..99acd5ed82 100644 --- a/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py +++ b/test/quantization/quantize_/workflows/float8/test_float8_opaque_tensor.py @@ -29,7 +29,7 @@ def get_config(granularity): return Float8DynamicActivationFloat8WeightConfig( activation_dtype=torch.float8_e4m3fn, granularity=granularity, - packing_format="opaque", + float8_packing_format="opaque", ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 3ef3f29890..e57fa56f1d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1770,7 +1770,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): kernel_preference: KernelPreference = KernelPreference.AUTO set_inductor_config: bool = True version: int = 2 - packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN + float8_packing_format: Float8PackingFormat = Float8PackingFormat.PLAIN def __post_init__(self): torch._C._log_api_usage_once( @@ -1781,7 +1781,7 @@ def __post_init__(self): activation_granularity, weight_granularity = _normalize_granularity( self.granularity ) - if self.packing_format == Float8PackingFormat.PLAIN: + if self.float8_packing_format == Float8PackingFormat.PLAIN: assert isinstance(activation_granularity, (PerTensor, PerRow)), ( f"Unsupported granularity {activation_granularity}, only PerTensor or PerRow are supported." ) @@ -1809,7 +1809,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_value_lb = config.activation_value_lb activation_value_ub = config.activation_value_ub kernel_preference = config.kernel_preference - packing_format = config.packing_format + float8_packing_format = config.float8_packing_format # Ensure works on device activation_granularity, weight_granularity = granularity @@ -1861,7 +1861,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): kernel_preference=kernel_preference, ) - if packing_format == Float8PackingFormat.PLAIN: + if float8_packing_format == Float8PackingFormat.PLAIN: quantized_weight = Float8Tensor.from_hp( weight, float8_dtype=weight_dtype, @@ -1870,7 +1870,7 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): kernel_preference=kernel_preference, act_quant_kwargs=act_quant_kwargs, ) - elif packing_format == Float8PackingFormat.OPAQUE: + elif float8_packing_format == Float8PackingFormat.OPAQUE: block_size = get_block_size(weight.shape, weight_granularity) quantized_weight = Float8OpaqueTensor.from_hp( weight, @@ -1878,7 +1878,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): act_quant_kwargs=act_quant_kwargs, ) else: - raise ValueError(f"Unsupported float8 packing format: {packing_format}") + raise ValueError( + f"Unsupported float8 packing format: {float8_packing_format}" + ) return quantized_weight
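
For context, a minimal usage sketch of the opaque CPU path enabled by this series. This is illustrative only: the single-linear toy model and the 256-dim shapes are placeholders, and the config arguments simply mirror the ones used in the test helper above.

    import torch
    from torchao import quantize_
    from torchao.quantization import PerRow
    from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig

    # Placeholder model; any module containing nn.Linear layers is handled the same way.
    model = torch.nn.Sequential(torch.nn.Linear(256, 256, bias=True)).eval()

    with torch.no_grad():
        quantize_(
            model,
            Float8DynamicActivationFloat8WeightConfig(
                activation_dtype=torch.float8_e4m3fn,
                granularity=PerRow(),
                # "opaque" routes the weight to Float8OpaqueTensor and dispatches
                # linear to the torch.ops.torchao.float8_linear_cpu kernel added here
                float8_packing_format="opaque",
            ),
        )
        y = model(torch.randn(4, 256))

Because Float8PackingFormat is a str-backed Enum, passing the plain string "opaque" (as the test does) compares equal to Float8PackingFormat.OPAQUE in the quantize path.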