diff --git a/CHANGELOG.md b/CHANGELOG.md index 6dd06195c9..f21795012d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj * Added benchmarking support for tile engine GEMM Multi D. * Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands. * Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM. +* Added tensor-wise quantization for CK_TILE GEMM. ### Optimized diff --git a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp index 83542e76f1..409bb173a1 100644 --- a/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp +++ b/example/ck_tile/17_grouped_gemm/quant_grouped_gemm.cpp @@ -13,7 +13,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_group_quant.hpp" +#include "ck_tile/ops/gemm_quant.hpp" #include "ck_tile/host.hpp" #include "quant_grouped_gemm.hpp" @@ -65,15 +65,15 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s, constexpr auto memory_operation = memory_operation_.value; constexpr bool transpose_c = false; - using QuantGemmProblem = ck_tile::GemmRowColQuantPipelineProblem; + using QuantGemmProblem = ck_tile::GemmRowColTensorQuantPipelineProblem; using GemmPipeline = typename PipelineTypeTraits< GemmConfig::Pipeline>::template GemmPipeline; diff --git a/example/ck_tile/38_block_scale_gemm/README.md b/example/ck_tile/38_block_scale_gemm/README.md index 9acc4f9bfc..9b2610813c 100644 --- a/example/ck_tile/38_block_scale_gemm/README.md +++ b/example/ck_tile/38_block_scale_gemm/README.md @@ -5,6 +5,7 @@ This folder contains examples of quant GEMMs using the ck_tile tile-programming - AQuant kernel with blocks of A matrix sharing scales: custom GEMM pipeline - BQuant kernel with blocks of B matrix sharing scales: custom GEMM 
pipeline - Row and Column-wise scaled: scaling implemented in Epilogue +- Tensor-wise scaled: scaling implemented in Epilogue ## build ``` @@ -14,7 +15,6 @@ mkdir build && cd build ../script/cmake-ck-dev.sh ../ # Compile the quant kernels make tile_example_gemm_quant_basic -j -make tile_example_gemm_bquant_basic -j ``` This will result in an executable `build/bin/tile_example_gemm_quant_basic` @@ -37,7 +37,7 @@ args: -warmup number of iterations before benchmark the kernel (default:10) -repeat number of iterations to benchmark the kernel (default:100) -timer gpu:gpu timer, cpu:cpu timer (default:gpu) - -quant_mode Which quant method to use (aquant, rowcol) + -quant_mode Which quant method to use (aquant, bquant, tensor, rowcol) ``` User need to select correct mapping of config for each quant mode: diff --git a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp index 79c6cca6cb..91f799f194 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_quant_basic.cpp @@ -66,19 +66,21 @@ float gemm_calc_quant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::str constexpr auto tail_number_v = tail_number_.value; constexpr bool transpose_c = false; + // row-col and tensor quants use the regular pipeline, A/B quants use their own using PipelineProblem = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant, - ck_tile::GemmRowColQuantPipelineProblem, + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant, + ck_tile::GemmRowColTensorQuantPipelineProblem, std::conditional_t>>; using GemmPipeline = std::conditional_t< - QuantMode == ck_tile::QuantType::RowColQuant, + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant, ck_tile::GemmPipelineAgBgCrCompV3, std::conditional_t, @@ -241,10 +244,18 @@ int run_gemm_example(int argc, char* argv[]) 
ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "tensor") + { + return run_gemm_example_prec_type, + TypeConfig, + 128, + ck_tile::QuantType::TensorQuant>( + a_layout, b_layout, argc, argv); + } else { throw std::runtime_error( - "Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'"); + "Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'"); } } else if(data_type == "bf8") @@ -276,10 +287,18 @@ int run_gemm_example(int argc, char* argv[]) ck_tile::QuantType::RowColQuant>( a_layout, b_layout, argc, argv); } + else if(quant_mode == "tensor") + { + return run_gemm_example_prec_type, + TypeConfig, + 128, + ck_tile::QuantType::TensorQuant>( + a_layout, b_layout, argc, argv); + } else { throw std::runtime_error( - "Unsupported quantization mode! Use 'aquant', 'bquant' or 'rowcol'"); + "Unsupported quantization mode! Use 'aquant', 'bquant', 'tensor' or 'rowcol'"); } } else if(data_type == "i4fp8") diff --git a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp index ccf07460fa..e5313d8aaf 100644 --- a/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp +++ b/example/ck_tile/38_block_scale_gemm/gemm_utils.hpp @@ -9,7 +9,7 @@ #include "ck_tile/host/kernel_launch.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_group_quant.hpp" +#include "ck_tile/ops/gemm_quant.hpp" template constexpr ck_tile::index_t get_k_warp_tile() @@ -241,7 +241,7 @@ auto create_args(int argc, char* argv[]) .insert("init", "0", "0:random, 1:linear, 2:constant(1)") .insert("flush_cache", "true", "flush cache before running the kernel, defaults to true") .insert("rotating_count", "1", "rotating count, defaults to 1") - .insert("quant_mode", "aquant", "Choose aquant (default), bquant or rowcol"); + .insert("quant_mode", "aquant", "Choose aquant (default), bquant, tensor or rowcol"); bool result = arg_parser.parse(argc, 
argv); return std::make_tuple(result, arg_parser); diff --git a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc index 0f45811ff3..8e9456e973 100644 --- a/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc +++ b/example/ck_tile/38_block_scale_gemm/run_gemm_quant_example.inc @@ -119,11 +119,7 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, } std::cout << " Acc_Type = " << DataTypeTraits::name << " C_Type = " << DataTypeTraits::name - << " QuantMode = " - << (QuantMode == ck_tile::QuantType::AQuantGrouped - ? "AQuantGrouped" - : (QuantMode == ck_tile::QuantType::BQuantGrouped ? "BQuantGrouped" - : "RowColQuant")) + << " QuantMode = " << quant_type_to_string(QuantMode) << " PreshuffleQuant = " << (GemmConfig::PreshuffleQuant ? "true" : "false") << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; @@ -183,10 +179,11 @@ int run_gemm_example_with_layouts(int argc, AQK = 0; // No A quantization BQK = K / QuantGroupSize; // Group quantization: BQK = K / GroupSize } - else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { - AQK = 1; // Row quantization: tensor shape [M, 1] - BQK = N; // Column quantization: tensor shape [1, N] + AQK = 1; // Row quantization: tensor shape [M, 1] or [1] + BQK = 1; // Column quantization: tensor shape [1, N] or [1] } else { @@ -227,6 +224,11 @@ int run_gemm_example_with_layouts(int argc, stride_AQ = ck_tile::get_default_stride(M, 1, stride_AQ, is_row_major(aq_layout)); stride_BQ = ck_tile::get_default_stride(1, N, stride_BQ, is_row_major(bq_layout)); } + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) + { + stride_AQ = 1; // Tensor quantization: tensor shape [1] + stride_BQ = 1; // Tensor quantization: tensor shape [1] + } ck_tile::HostTensor a_m_k( 
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); @@ -237,28 +239,30 @@ int run_gemm_example_with_layouts(int argc, // Create AQ tensor with appropriate shape std::unique_ptr> aq_tensor_ptr = nullptr; - if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) + if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || + QuantMode == ck_tile::QuantType::RowColQuant) { aq_tensor_ptr = std::make_unique>( ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); } - else if(QuantMode == ck_tile::QuantType::RowColQuant) + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { aq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); + ck_tile::host_tensor_descriptor(1, 1, stride_AQ, is_row_major(aq_layout))); } - // Create BQ tensor only for RowColQuant mode + // Create BQ tensor with appropriate shape std::unique_ptr> bq_tensor_ptr = nullptr; - if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped) + if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || + QuantMode == ck_tile::QuantType::RowColQuant) { bq_tensor_ptr = std::make_unique>( ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, is_row_major(bq_layout))); } - else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) { bq_tensor_ptr = std::make_unique>( - ck_tile::host_tensor_descriptor(1, N, stride_BQ, is_row_major(bq_layout))); + ck_tile::host_tensor_descriptor(1, 1, stride_BQ, is_row_major(bq_layout))); } std::random_device rd; @@ -282,7 +286,7 @@ int run_gemm_example_with_layouts(int argc, *bq_tensor_ptr); ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(a_m_k); } - else + else if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped) { if constexpr(std::is_same_v) { @@ -296,12 +300,15 @@ int run_gemm_example_with_layouts(int argc, ck_tile::FillUniformDistribution{-2.0f, 2.0f, 
fill_seed(gen)}( *aq_tensor_ptr); ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - - if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) - { - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( - *bq_tensor_ptr); - } + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(a_m_k); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *aq_tensor_ptr); + ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}( + *bq_tensor_ptr); } } else if(init_method == 1) @@ -343,7 +350,8 @@ int run_gemm_example_with_layouts(int argc, std::unique_ptr aq_dev_buf_ptr = nullptr; if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { aq_dev_buf_ptr = std::make_unique(aq_tensor_ptr->get_element_space_size_in_bytes()); @@ -351,14 +359,16 @@ int run_gemm_example_with_layouts(int argc, std::unique_ptr bq_dev_buf_ptr = nullptr; if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { bq_dev_buf_ptr = std::make_unique(bq_tensor_ptr->get_element_space_size_in_bytes()); } if constexpr(QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { if constexpr(GemmConfig::PreshuffleQuant) { @@ -398,7 +408,8 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_result.SetZero(); if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) + QuantMode == ck_tile::QuantType::RowColQuant || + QuantMode == ck_tile::QuantType::TensorQuant) { 
bq_dev_buf_ptr->ToDevice(bq_tensor_ptr->data()); } @@ -412,15 +423,9 @@ int run_gemm_example_with_layouts(int argc, CLayout, QuantGroupSize, QuantMode>(a_m_k_dev_buf, - (QuantMode == ck_tile::QuantType::AQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) - ? aq_dev_buf_ptr.get() - : nullptr, + aq_dev_buf_ptr.get(), b_k_n_dev_buf, - (QuantMode == ck_tile::QuantType::BQuantGrouped || - QuantMode == ck_tile::QuantType::RowColQuant) - ? bq_dev_buf_ptr.get() - : nullptr, + bq_dev_buf_ptr.get(), c_m_n_dev_buf, M, N, @@ -467,7 +472,7 @@ int run_gemm_example_with_layouts(int argc, QuantGroupSize, false>(a_m_k, *bq_tensor_ptr, b_k_n, c_m_n_host_ref); } - else + else if constexpr(QuantMode == ck_tile::QuantType::RowColQuant) { ck_tile::reference_gemm_rowcol_quant( a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref); } + else if constexpr(QuantMode == ck_tile::QuantType::TensorQuant) + { + ck_tile::reference_gemm_tensor_quant( + a_m_k, *aq_tensor_ptr, b_k_n, *bq_tensor_ptr, c_m_n_host_ref); + } const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); diff --git a/include/ck_tile/core/tensor/load_tile.hpp b/include/ck_tile/core/tensor/load_tile.hpp index c7c4702e22..a3620453b4 100644 --- a/include/ck_tile/core/tensor/load_tile.hpp +++ b/include/ck_tile/core/tensor/load_tile.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -158,4 +158,7 @@ CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window +concept IsLoadableTile = requires { load_tile(std::declval()); }; + } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_gemm.hpp b/include/ck_tile/host/reference/reference_gemm.hpp index d9379b4420..90f68f7e2e 100644 --- a/include/ck_tile/host/reference/reference_gemm.hpp +++ b/include/ck_tile/host/reference/reference_gemm.hpp @@ -180,10 +180,6 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor& a_m_k else v_b = fp32_val.lo; } - else if constexpr(std::is_same_v) - { - v_b = fp8_to_float_raw(b_element_op(b_k_n(k, n))); - } else { v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); @@ -198,7 +194,57 @@ CK_TILE_HOST void reference_gemm_rowcol_quant(const HostTensor& a_m_k }; make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); - std::cout << std::endl; +} + +template +CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor& a_m_k, + const HostTensor& aq_1_1, + const HostTensor& b_k_n, + const HostTensor& bq_1_1, + HostTensor& c_m_n, + const AElementOp& a_element_op = {}, + const BElementOp& b_element_op = {}, + const ACCElementOp& acc_element_op = {}) +{ + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v); + static_assert(std::is_same_v || std::is_same_v); + static_assert(std::is_same_v && std::is_same_v); + const std::size_t M = a_m_k.get_length(0); + const std::size_t N = b_k_n.get_length(1); + const std::size_t K = a_m_k.get_length(1); + + auto f_mn = [&](auto m, auto n) { + // Init accumulator + AccDataType v_acc = 0; + // Get scale for A and scale for B + const AccDataType a_scale = ck_tile::type_convert(aq_1_1(0, 0)); + const AccDataType b_scale = ck_tile::type_convert(bq_1_1(0, 0)); + + // Compute the dot product + for(std::size_t k = 0; k < K; ++k) + { + AccDataType v_a = 
ck_tile::type_convert(a_element_op(a_m_k(m, k))); + AccDataType v_b = ck_tile::type_convert(b_element_op(b_k_n(k, n))); + + v_acc += v_a * v_b; + } + + v_acc = v_acc * a_scale * b_scale; + + c_m_n(m, n) = ck_tile::type_convert(acc_element_op(v_acc)); + }; + + make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency()); } template && std::is_same_v) { - constexpr auto step = SFC::get_forward_step(iAccess); + // No scaling needed - this is a no-op + } + // Check if scales are scalar AccDataType + else if constexpr(std::is_same_v && + std::is_same_v) + { + // Handle scalar scales + const AccDataType scale_m = scale_m_window; + const AccDataType scale_n = scale_n_window; + tile_elementwise_inout([&](auto& element) { element = element * scale_m * scale_n; }, + lds_tile); + } + // Otherwise, assume they are tile windows that can be loaded + else + { + // Load tiles + const auto scale_m_tile = load_tile(scale_m_window); + const auto scale_n_tile = load_tile(scale_n_window); - move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})}); - move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})}); + // Compute element-wise product in-place i.e. 
lds_tile = lds_tile * scale_m * scale_n + tile_elementwise_inout( + element_wise::MultiDMultiply{}, lds_tile, lds_tile, scale_m_tile, scale_n_tile); + + // Move scale windows + constexpr index_t num_access = SFC::get_num_of_access(); + if constexpr(iAccess != num_access - 1) + { + constexpr auto step = SFC::get_forward_step(iAccess); + + move_tile_window(scale_m_window, {step.at(number<0>{}), step.at(number<1>{})}); + move_tile_window(scale_n_window, {step.at(number<0>{}), step.at(number<1>{})}); + } } } @@ -452,6 +471,8 @@ struct CShuffleEpilogue // Optional scales (must share the same distribution to match per-thread indexing) constexpr bool has_scales = !std::is_same::value && !std::is_same::value; + constexpr bool has_scalar_scales = + std::is_same_v && std::is_same_v; // Tiles to hold row/col scales when present using SMType = typename GetDataType>::type; @@ -462,8 +483,11 @@ struct CShuffleEpilogue // Build windows only if scales are provided auto scale_m_window = [&]() { - if constexpr(has_scales) + if constexpr(has_scales && !has_scalar_scales) { + static_assert( + IsLoadableTile, + "ScaleM must be a loadable tile"); return make_tile_window(scale_m, dram_tile_distribution); } else @@ -472,8 +496,11 @@ struct CShuffleEpilogue } }(); auto scale_n_window = [&]() { - if constexpr(has_scales) + if constexpr(has_scales && !has_scalar_scales) { + static_assert( + IsLoadableTile, + "ScaleN must be a loadable tile"); return make_tile_window(scale_n, dram_tile_distribution); } else @@ -489,7 +516,7 @@ struct CShuffleEpilogue merge_sequences(sequence<1, NRepeat>{}, c_warp_y_lengths)); // If scales provided, load them with identical distribution - if constexpr(has_scales) + if constexpr(has_scales && IsLoadableTile && IsLoadableTile) { sm_tile = load_tile(scale_m_window); // row scales in permuted layout sn_tile = load_tile(scale_n_window); // col scales in permuted layout @@ -504,7 +531,11 @@ struct CShuffleEpilogue auto emit = [&](index_t out_idx, index_t src_row) { 
AccDataType v = shuffle_acc.get_thread_buffer()[base + src_row]; - if constexpr(has_scales) + if constexpr(has_scalar_scales) + { + v = static_cast(v * scale_m * scale_n); + } + else if constexpr(has_scales) { // same linear index mapping on the permuted distribution const auto s_m = static_cast(sm_tile.get_thread_buffer()[out_idx]); @@ -595,10 +626,19 @@ struct CShuffleEpilogue number{}); constexpr bool has_scales = - !std::is_same::value && !std::is_same::value; + !std::is_same_v && !std::is_same_v; + constexpr bool has_scalar_scales = + std::is_same_v && std::is_same_v; auto scale_m_window = [&]() { - if constexpr(has_scales) + if constexpr(has_scalar_scales) { + return scale_m; + } + else if constexpr(has_scales) + { + static_assert( + IsLoadableTile, + "ScaleM must be a loadable tile"); return make_tile_window(scale_m, lds_tile.get_tile_distribution()); } else @@ -607,8 +647,15 @@ struct CShuffleEpilogue } }(); auto scale_n_window = [&]() { - if constexpr(has_scales) + if constexpr(has_scalar_scales) { + return scale_n; + } + else if constexpr(has_scales) + { + static_assert( + IsLoadableTile, + "ScaleN must be a loadable tile"); return make_tile_window(scale_n, lds_tile.get_tile_distribution()); } else diff --git a/include/ck_tile/ops/gemm_group_quant.hpp b/include/ck_tile/ops/gemm_group_quant.hpp deleted file mode 100644 index 94b5ab8c3b..0000000000 --- a/include/ck_tile/ops/gemm_group_quant.hpp +++ /dev/null @@ -1,21 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" -#include "ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" -#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp" -#include "ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp" -#include "ck_tile/ops/common/generic_2d_block_shape.hpp" -#include "ck_tile/ops/common/tensor_layout.hpp" -#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm_quant.hpp b/include/ck_tile/ops/gemm_quant.hpp new file mode 100644 index 0000000000..9f90050899 --- /dev/null +++ b/include/ck_tile/ops/gemm_quant.hpp @@ -0,0 +1,21 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp" +#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp" +#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp" +#include "ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" +#include "ck_tile/ops/common/utils.hpp" diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp similarity index 100% rename from include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp rename to include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp similarity index 100% rename from include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp rename to include/ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp diff --git 
a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp similarity index 97% rename from include/ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp rename to include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp index 13fa0b8dfb..82bf75a9e3 100644 --- a/include/ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp @@ -12,7 +12,7 @@ #include "ck_tile/core/numeric/integer.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/host/concat.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp" namespace ck_tile { @@ -330,7 +330,6 @@ struct QuantGemmKernel } } - // NOTE: no kernel currently uses BQuant like this: if constexpr(kQuantType == QuantType::BQuantGrouped) { static_assert(std::is_same_v); @@ -890,6 +889,7 @@ struct QuantGemmKernel * @param a_ptr input A pointer * @param b_ptr input B pointer * @param aq_ptr input AQ pointer + * @param bq_ptr input BQ pointer * @param c_ptr output C pointer * @param smem_ptr_0 The start memory pointer of the shared memory block. * @param kargs GEMM kernel arguments @@ -938,7 +938,8 @@ struct QuantGemmKernel return GemmPipeline{}.template operator()( a_block_window, b_block_window, bq_block_window, num_loop, smem_ptr_0); } - else if constexpr(kQuantType == QuantType::RowColQuant) + else if constexpr(kQuantType == QuantType::RowColQuant || + kQuantType == QuantType::TensorQuant) { return GemmPipeline{}.template operator()( a_block_window, b_block_window, num_loop, smem_ptr_0); @@ -964,6 +965,18 @@ struct QuantGemmKernel aq_block_window, bq_block_window); } + else if constexpr(kQuantType == QuantType::TensorQuant) + { + // TODO: why doesn't readfirstlane work here? 
+ // const AccDataType aq_scale = + // __builtin_amdgcn_readfirstlane(type_convert(*aq_ptr)); + // const AccDataType bq_scale = + // __builtin_amdgcn_readfirstlane(type_convert(*bq_ptr)); + const AccDataType aq_scale = type_convert(*aq_ptr); + const AccDataType bq_scale = type_convert(*bq_ptr); + EpiloguePipeline{}( + c_block_window, c_block_tile, c_block_window, smem_ptr_0, aq_scale, bq_scale); + } } CK_TILE_DEVICE void operator()(QuantGemmKernelArgs kargs) const diff --git a/include/ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp similarity index 99% rename from include/ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp rename to include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp index 925ea42678..07c45117e2 100644 --- a/include/ck_tile/ops/gemm_group_quant/kernel/grouped_gemm_quant_kernel.hpp +++ b/include/ck_tile/ops/gemm_quant/kernel/grouped_gemm_quant_kernel.hpp @@ -9,7 +9,7 @@ #include "ck_tile/host/stream_utils.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm_group_quant/kernel/gemm_quant_kernel.hpp" +#include "ck_tile/ops/gemm_quant/kernel/gemm_quant_kernel.hpp" #include "ck_tile/host.hpp" #include diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp similarity index 100% rename from include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp similarity index 100% rename from 
include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_policy.hpp diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp similarity index 99% rename from include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp index 5ce4268dca..24254013a4 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_v3.hpp @@ -9,7 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/numeric/math.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_aquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp similarity index 100% rename from include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp similarity index 100% rename from include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp diff --git 
a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp similarity index 99% rename from include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp index 8f191f0f94..c27fbf5b50 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp @@ -9,7 +9,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" -#include "ck_tile/ops/gemm_group_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" +#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_base.hpp" #include "ck_tile/host/concat.hpp" namespace ck_tile { diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp similarity index 100% rename from include/ck_tile/ops/gemm_group_quant/pipeline/gemm_group_quant_utils.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp similarity index 87% rename from include/ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp index a2cef2d994..d49204c64d 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/gemm_quant_pipeline_problem.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp @@ -168,17 +168,18 @@ template -using GemmRowColQuantPipelineProblem = GemmQuantPipelineProblemBase; 
+using GemmRowColTensorQuantPipelineProblem = + GemmQuantPipelineProblemBase; } // namespace ck_tile diff --git a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp similarity index 80% rename from include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp rename to include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp index f505efe4e0..e97eeffb9b 100644 --- a/include/ck_tile/ops/gemm_group_quant/pipeline/tile_gemm_quant_traits.hpp +++ b/include/ck_tile/ops/gemm_quant/pipeline/tile_gemm_quant_traits.hpp @@ -12,9 +12,23 @@ enum struct QuantType : std::uint16_t { AQuantGrouped = 0, BQuantGrouped = 1, - RowColQuant = 2 + RowColQuant = 2, + TensorQuant = 3 }; +/** @brief Returns the printable name of a QuantType value ("Unknown" for any unlisted value). Marked inline because this definition lives in a header and would otherwise be emitted once per including translation unit. */ +inline std::string quant_type_to_string(QuantType quant_type) +{ + switch(quant_type) + { + case QuantType::AQuantGrouped: return "AQuantGrouped"; + case QuantType::BQuantGrouped: return "BQuantGrouped"; + case QuantType::RowColQuant: return "RowColQuant"; + case QuantType::TensorQuant: return "TensorQuant"; + default: return "Unknown"; + } +} + template ; - bool result = run_cshuffle_epilogue_test(); - EXPECT_TRUE(result) << "Basic CShuffleEpilogue test failed"; + auto result = run_cshuffle_epilogue_test(ScaleType::None); + EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed"; } TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale) @@ -73,8 +73,45 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale) NPerXdl, KPerXdl>; - bool result = run_cshuffle_epilogue_test(true); - EXPECT_TRUE(result) << "Scale CShuffleEpilogue test failed"; + auto result = + run_cshuffle_epilogue_test(ScaleType::RowCol); + EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2"; + EXPECT_FLOAT_EQ(result[1], 4.0F) + << "RowCol CShuffleEpilogue test failed: second element not 2*2"; +} + +TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale) +{ + // Basic test 
configuration with half_t data types + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using ODataType = ck_tile::half_t; + + constexpr index_t kMPerBlock = 256; + constexpr index_t kNPerBlock = 256; + constexpr index_t MWave = 2; + constexpr index_t NWave = 2; + constexpr index_t MPerXdl = 32; + constexpr index_t NPerXdl = 32; + constexpr index_t KPerXdl = 8; + + using TestProblem = SimpleCShuffleEpilogueProblem; + + auto result = + run_cshuffle_epilogue_test(ScaleType::Tensor); + EXPECT_FLOAT_EQ(result[0], 4.0F) + << "TensorScale CShuffleEpilogue test failed: first element not 2*2=4"; } int main(int argc, char** argv) diff --git a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp index c23957d802..01e6c91c7c 100644 --- a/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp +++ b/test/ck_tile/epilogue/test_cshuffle_epilogue_util.hpp @@ -19,8 +19,15 @@ namespace ck_tile { +enum class ScaleType +{ + None, + RowCol, + Tensor +}; + // Simple test kernel to invoke the CShuffleEpilogue -template +template __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data, float* m_scale, float* n_scale) @@ -61,7 +68,7 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res auto empty_ds = make_tuple(); // Call the epilogue - if constexpr(UseScale) + if constexpr(Scale == ScaleType::RowCol) { const auto m_scale_window = make_tile_window( make_naive_tensor_view( @@ -75,6 +82,10 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res {0, 0}); Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, m_scale_window, n_scale_window); } + else if constexpr(Scale == ScaleType::Tensor) + { + Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, *m_scale, *n_scale); + } else { Epilogue{}(output_tile_window, acc_tile, empty_ds, smem); @@ -113,7 +124,7 @@ using 
SimpleCShuffleEpilogueProblem = memory_operation_enum::set>; template -bool run_cshuffle_epilogue_test(bool use_scale = false) +auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None) { using ODataType = typename Problem::ODataType; @@ -142,7 +153,7 @@ bool run_cshuffle_epilogue_test(bool use_scale = false) dim3 gridSize(1, 1, 1); dim3 blockSize(kBlockSize, 1, 1); - if(use_scale) + if(scale == ScaleType::RowCol) { float* m_scale; float* n_scale; @@ -155,12 +166,25 @@ bool run_cshuffle_epilogue_test(bool use_scale = false) hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice)); HIP_CHECK_ERROR( hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice)); - test_cshuffle_epilogue_kernel + test_cshuffle_epilogue_kernel + <<>>(device_output, m_scale, n_scale); + } + else if(scale == ScaleType::Tensor) + { + float* m_scale; + float* n_scale; + std::vector h_m_scale(1, 2.0F); + std::vector h_n_scale(1, 1.0F); + HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float))); + HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float))); + HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice)); + HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice)); + test_cshuffle_epilogue_kernel <<>>(device_output, m_scale, n_scale); } else { - test_cshuffle_epilogue_kernel + test_cshuffle_epilogue_kernel <<>>(device_output, nullptr, nullptr); } @@ -172,20 +196,10 @@ bool run_cshuffle_epilogue_test(bool use_scale = false) HIP_CHECK_ERROR(hipMemcpy( host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost)); - // Basic verification - just check that output has a 2, and 4 if using scaling - bool has_2 = - type_convert(host_output[0]) > 1.9F && type_convert(host_output[0]) < 2.1F; - bool scale_has_4 = true; - if(use_scale) - { - scale_has_4 = type_convert(host_output[1]) > 3.9F && - type_convert(host_output[1]) < 4.1F; - } - // Cleanup 
HIP_CHECK_ERROR(hipFree(device_output)); - return has_2 && scale_has_4; + return host_output; } } // namespace ck_tile diff --git a/test/ck_tile/gemm_block_scale/CMakeLists.txt b/test/ck_tile/gemm_block_scale/CMakeLists.txt index 847ab88644..93a13ba5af 100644 --- a/test/ck_tile/gemm_block_scale/CMakeLists.txt +++ b/test/ck_tile/gemm_block_scale/CMakeLists.txt @@ -6,14 +6,9 @@ endif() list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0) if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") - set(TEST_GEMM_NAME test_tile_gemm_aquant_basic) - set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8) - - foreach(QUANT_TYPE ${QUANT_TYPES}) - add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp) - target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) - endforeach() - + # Typed Test Suite for GEMM Quantization + add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp) + target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile quant gemm tests for current target") endif() diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp deleted file mode 100644 index 9c4277d879..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_bf8.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "test_run_gemm_aquant_example.inc" - -int main() { return run_gemm_combinations("bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp deleted file mode 100644 index b0cf55be6f..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_fp8.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "test_run_gemm_aquant_example.inc" - -int main() { return run_gemm_combinations("fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp deleted file mode 100644 index fd80bf2b06..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4bf8.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "test_run_gemm_aquant_example.inc" - -int main() { return run_gemm_combinations("i4bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp deleted file mode 100644 index fe8c9c5000..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32bf8.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#include "test_run_gemm_aquant_example.inc" - -int main() { return run_gemm_combinations("i4f32bf8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp deleted file mode 100644 index a319d9c2ad..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4f32fp8.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "test_run_gemm_aquant_example.inc" - -int main() { return run_gemm_combinations("i4f32fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp deleted file mode 100644 index ceb8760435..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_basic_i4fp8.cpp +++ /dev/null @@ -1,6 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "test_run_gemm_aquant_example.inc" - -int main() { return run_gemm_combinations("i4fp8"); } diff --git a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp b/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp deleted file mode 100644 index 83a9e57878..0000000000 --- a/test/ck_tile/gemm_block_scale/test_gemm_aquant_utils.hpp +++ /dev/null @@ -1,243 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
- -#pragma once - -#include - -#include "ck_tile/core.hpp" -#include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/epilogue.hpp" -#include "ck_tile/ops/gemm.hpp" -#include "ck_tile/ops/gemm_group_quant.hpp" - -#define CK_TILE_PIPELINE_PREFILL 1 -#define CK_TILE_PIPELINE_DECODE 2 -#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3 - -template -constexpr ck_tile::index_t get_k_warp_tile() -{ -#if defined(CK_GFX950_SUPPORT) - constexpr bool is_8bit_float = - std::is_same_v || std::is_same_v; - if constexpr(M_Warp_Tile == 32) - return is_8bit_float ? 64 : 16; - else - return is_8bit_float ? 128 : 32; -#else - if constexpr(M_Warp_Tile == 32) - return 16; - else - return 32; -#endif -} - -template -auto calculate_rtol_atol(const ck_tile::index_t K, - const ck_tile::index_t kbatch, - const float max_accumulated_value) -{ - using ComputeType = - std::conditional_t; - // Calculate thresholds - const auto rtol = ck_tile::get_relative_threshold( - ck_tile::integer_divide_ceil(K, kbatch)); - const auto atol = ck_tile::get_absolute_threshold( - max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); - // Calculate error due to split_k accumulation - const auto rtol_split_k = - ck_tile::get_relative_threshold(kbatch); - const auto atol_split_k = ck_tile::get_absolute_threshold( - max_accumulated_value, kbatch); - // Use higher threshold - return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); -} - -class ArgumentsNotSupportedException : public std::logic_error -{ - public: - explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {} -}; - -struct GemmConfigBase -{ - static constexpr bool kPadM = false; - static constexpr bool kPadN = false; - static constexpr bool kPadK = false; - - static constexpr bool PermuteA = false; - static constexpr bool PermuteB = false; - - static constexpr bool TransposeC = false; - static constexpr bool UseStructuredSparsity = false; - - static constexpr int 
kBlockPerCu = 1; - static constexpr ck_tile::index_t TileParitionerGroupNum = 8; - static constexpr ck_tile::index_t TileParitionerM01 = 4; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; - static constexpr ck_tile::index_t NumWaveGroups = 1; - static constexpr bool PreshuffleQuant = false; - static constexpr bool DoubleSmemBuffer = true; -}; - -template -struct GemmConfigDecode : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE; -}; - -template -struct GemmConfigPrefill : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 128; - static constexpr ck_tile::index_t N_Tile = 128; - static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); - - static constexpr int kBlockPerCu = 2; - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL; -}; - -template -struct GemmConfigPreshuffleQuant : public GemmConfigBase -{ - static constexpr ck_tile::index_t M_Tile = 16; - static constexpr ck_tile::index_t 
N_Tile = 64; - static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType); - - static constexpr ck_tile::index_t M_Warp = 1; - static constexpr ck_tile::index_t N_Warp = 4; - static constexpr ck_tile::index_t K_Warp = 1; - - static constexpr ck_tile::index_t M_Warp_Tile = 16; - static constexpr ck_tile::index_t N_Warp_Tile = 16; - static constexpr ck_tile::index_t K_Warp_Tile = - get_k_from_preshuffled_warp_tile(); - - static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default; - static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT; - static constexpr bool PreshuffleQuant = true; -}; - -template -struct GemmQuantTypeConfig -{ - using ADataType = ADataType_; - using QDataType = QDataType_; - using BDataType = BDataType_; - using AccDataType = float; - using CDataType = CDataType_; -}; - -template -struct DataTypeTraits; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp64"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int32"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf16"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "fp8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "bf8"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "pk_int4_t"; -}; - -template <> -struct DataTypeTraits -{ - static constexpr const char* name = "int8"; -}; - -auto create_args(int argc, char* argv[]) -{ - ck_tile::ArgParser arg_parser; - arg_parser.insert("m", "3840", "m dimension") - .insert("n", "4096", "n dimension") - .insert("k", "2048", "k dimension") - .insert("a_layout", "R", "A tensor data layout - Row by default") - .insert("aq_layout", 
"R", "Aq tensor data layout - Row by default") - .insert("b_layout", "C", "B tensor data layout - Column by default") - .insert("c_layout", "R", "C tensor data layout - Row by default") - .insert("stride_a", "0", "Tensor A stride") - .insert("stride_q", "0", "Tensor AQ stride") - .insert("stride_b", "0", "Tensor B stride") - .insert("stride_c", "0", "Tensor C stride") - .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") - .insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8") - .insert("warmup", "50", "number of iterations before benchmark the kernel") - .insert("repeat", "100", "number of iterations to benchmark the kernel") - .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") - .insert("split_k", "1", "splitK value") - .insert("init", "0", "0:random, 1:linear, 2:constant(1)") - .insert("persistent", "0", "0:non-persistent, 1:persistent") - .insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr"); - - bool result = arg_parser.parse(argc, argv); - return std::make_tuple(result, arg_parser); -} - -// host API -float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp new file mode 100644 index 0000000000..ed3231d140 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include +#include +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/check_err.hpp" +#include "ck_tile/host/reference/reference_gemm.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" +#include "ck_tile/ops/gemm_quant.hpp" + +// Forward declarations for quant type-specific implementations +template +struct QuantTypeTraits; + +// Base class for common quant gemm functionality +template +class TestCkTileGemmQuantBase : public ::testing::Test +{ + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using CLayout = std::tuple_element_t<2, Tuple>; + using ADataType = std::tuple_element_t<3, Tuple>; + using BDataType = std::tuple_element_t<4, Tuple>; + using QDataType = std::tuple_element_t<5, Tuple>; + using CDataType = std::tuple_element_t<6, Tuple>; + static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value; + using GemmConfig = std::tuple_element_t<8, Tuple>; + static constexpr uint32_t QuantGroupSize = std::tuple_element_t<9, Tuple>::value; + using AccDataType = float; // accumulate always in float + + // Get the quant-type specific data types from traits + using QuantTraits = QuantTypeTraits; + using ComputeDataType = typename QuantTraits::template ComputeDataType; + + static constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; + static constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; + static constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; + + static constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; + static constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; + static constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; + + static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; + static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; + static constexpr ck_tile::index_t K_Warp_Tile = 
GemmConfig::K_Warp_Tile; + + public: + void SetUp() override { static_cast(this)->SetUpQuantTypeSpecific(); } + + void TearDown() override { static_cast(this)->TearDownQuantTypeSpecific(); } + + // Common test execution logic + void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) + { + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + constexpr bool kPreshuffle = false; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = ck_tile::TileGemmQuantTraits; + + // Let the derived class create the appropriate pipeline and epilogue + static_cast(this) + ->template run_quant_gemm_impl( + args, s); + } + + void RunTest(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + // Generate test data and run the kernel + static_cast(this)->run_test_with_validation(M, N, K); + } + + // Helper function to check layout + template + static constexpr auto is_row_major(Layout) + { + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; + } + + // Tolerance calculation function for validation + template + auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) + { + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = + ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, 
atol_split_k)); + } +}; + +// Define generic QuantTypeTraits template (will be specialized) +template +struct QuantTypeTraits +{ + static_assert(QT == ck_tile::QuantType::AQuantGrouped || + QT == ck_tile::QuantType::BQuantGrouped || + QT == ck_tile::QuantType::RowColQuant || + QT == ck_tile::QuantType::TensorQuant, + "Unsupported quantization type"); +}; + +// Specialization for AQuantGrouped +template <> +struct QuantTypeTraits +{ + template + using ComputeDataType = BDataType; // For AQuant, compute type is BDataType + + static constexpr const char* name = "aquant"; +}; + +// Specialization for BQuantGrouped +template <> +struct QuantTypeTraits +{ + template + using ComputeDataType = ADataType; // For BQuant, compute type is ADataType + + static constexpr const char* name = "bquant"; +}; + +// Specialization for RowColQuant +template <> +struct QuantTypeTraits +{ + template + using ComputeDataType = ADataType; // For RowColQuant, compute type is ADataType + + static constexpr const char* name = "rowcol"; +}; + +// Specialization for TensorQuant +template <> +struct QuantTypeTraits +{ + template + using ComputeDataType = ADataType; // For TensorQuant, compute type is ADataType + + static constexpr const char* name = "tensor"; +}; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp new file mode 100644 index 0000000000..5fc6b2f15c --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp @@ -0,0 +1,919 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "test_gemm_quant_base.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t NumWaveGroups = 1; + static constexpr bool PreshuffleQuant = false; + static constexpr bool DoubleSmemBuffer = false; + + // Default GEMM tile sizes for tests + static constexpr ck_tile::index_t M_Tile = 16; + static constexpr ck_tile::index_t N_Tile = 64; + static constexpr ck_tile::index_t K_Tile = 256; + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 4; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = 32; +}; + +template +class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase> +{ + using Base = TestCkTileGemmQuantBase>; + friend Base; + + public: + using typename Base::AccDataType; + using typename Base::ADataType; + using typename Base::ALayout; + using typename Base::BDataType; + using typename Base::BLayout; + using typename Base::CDataType; + using typename Base::CLayout; + using typename Base::ComputeDataType; + using typename Base::QDataType; + + static constexpr auto QuantType = Base::QuantType; + static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize; + + protected: + void SetUpQuantTypeSpecific() {} + void TearDownQuantTypeSpecific() {} + + 
// AQuant-specific data generation + void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + const ck_tile::index_t stride_A = K; + const ck_tile::index_t stride_B = K; + const ck_tile::index_t stride_C = M; + + // AQuant uses grouped quantization for A matrix + const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize); + const ck_tile::index_t stride_AQ = + ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{})); + + // Generate test data + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + ck_tile::HostTensor aq_m_aqk( + ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + + // Initialize data with random values + if constexpr(std::is_same_v) + { + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(a_m_k); + } + else + { + ck_tile::FillUniformDistribution{-2.0f, 3.0f}(a_m_k); + } + ck_tile::FillUniformDistribution{-5.0f, 5.0f}(b_k_n); + ck_tile::FillUniformDistribution{-2.0f, 2.0f}(aq_m_aqk); + + // Allocate device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); + ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType)); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); + ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType)); + + // Copy to device + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor temp = a_m_k; + ck_tile::permute_vectors_i4x4_b(temp); + a_m_k_dev_buf.ToDevice(temp.data()); + } + else + { + a_m_k_dev_buf.ToDevice(a_m_k.data()); + } + aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + + // Create args for kernel execution + ck_tile::QuantGemmHostArgs args{ + 
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr + aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales) + nullptr, // bq_ptr (not used for AQuant) + 1, // k_batch + M, + N, + K, // M, N, K + AQK, // QK_A + 0, // QK_B (not used for AQuant) + stride_A, + stride_B, + stride_C, + stride_AQ, + 0 // strides + }; + + // Run the kernel + ck_tile::stream_config stream_config{}; + this->invoke_quant_gemm(args, stream_config); + + // Validation using reference implementation + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Run reference AQuant implementation + ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); + + // Get device result + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); + + // Calculate error tolerances + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + this->template calculate_rtol_atol( + K, 1, max_accumulated_value); + + // Validate results + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N + << ", K=" << K; + + if(!pass) + { + std::cout << "AQuantGrouped - Relative error threshold: " + << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + } + + private: + // AQuant-specific pipeline implementation + template + void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args, + const ck_tile::stream_config& s) + { + using GemmPipelineProblem = 
ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr bool transpose_c = false; + + using PipelineProblem = + ck_tile::GemmAQuantPipelineProblem; + + using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set>>; + + using Kernel = ck_tile::QuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Arguments not supported for AQuant kernel"); + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + }; + + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } +}; + +// BQuant-specific test fixture +template +class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase> +{ + using Base = TestCkTileGemmQuantBase>; + friend Base; + + public: + using typename Base::AccDataType; + using typename Base::ADataType; + using typename Base::ALayout; + using typename 
Base::BDataType; + using typename Base::BLayout; + using typename Base::CDataType; + using typename Base::CLayout; + using typename Base::ComputeDataType; + using typename Base::QDataType; + + static constexpr auto QuantType = Base::QuantType; + static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize; + + protected: + void SetUpQuantTypeSpecific() {} + void TearDownQuantTypeSpecific() {} + + void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + const ck_tile::index_t stride_A = K; + const ck_tile::index_t stride_B = K; + const ck_tile::index_t stride_C = M; + + // BQuant uses grouped quantization for B matrix + const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize); + const ck_tile::index_t stride_BQ = BQK; + + // Generate test data + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + ck_tile::HostTensor bq_bqk_n( + ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, this->is_row_major(BLayout{}))); + + // Initialize data with random values + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); + ck_tile::FillUniformDistribution{0.f, 1.f}(b_k_n); + ck_tile::FillUniformDistribution{0.001f, 0.01f}(bq_bqk_n); + + // Allocate device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); + ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size() * sizeof(QDataType)); + ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType)); + + // Copy to device + a_m_k_dev_buf.ToDevice(a_m_k.data()); + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor temp = b_k_n; + ck_tile::permute_vectors_i4x4_b(temp); + 
b_k_n_dev_buf.ToDevice(temp.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data()); + + // Create args for kernel execution + ck_tile::QuantGemmHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr + nullptr, // aq_ptr (not used for BQuant) + bq_bqk_n_dev_buf.GetDeviceBuffer(), // bq_ptr (scales) + 1, // k_batch + M, + N, + K, // M, N, K + 0, // QK_A (not used for BQuant) + BQK, // QK_B + stride_A, + stride_B, + stride_C, + 0, + stride_BQ // strides + }; + + // Run the kernel + ck_tile::stream_config stream_config{}; + this->invoke_quant_gemm(args, stream_config); + + // Validation using reference implementation + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Run reference BQuant implementation + ck_tile::reference_gemm_quant(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref); + + // Get device result + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); + + // Calculate error tolerances + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + this->template calculate_rtol_atol( + K, 1, max_accumulated_value); + + // Validate results + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N + << ", K=" << K; + + if(!pass) + { + std::cout << "BQuantGrouped - Relative error threshold: " + << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << 
std::endl; + } + } + + private: + // BQuant-specific pipeline implementation + template + void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args, + const ck_tile::stream_config& s) + { + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + + using PipelineProblem = + ck_tile::GemmBQuantPipelineProblem; + + using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + false, // transpose_c + ck_tile::memory_operation_enum::set>>; + + using Kernel = ck_tile::QuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Arguments not supported for BQuant kernel"); + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + }; + + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } +}; + +// RowColQuant-specific test fixture +template +class TestCkTileGemmRowColQuant + : public 
TestCkTileGemmQuantBase> +{ + using Base = TestCkTileGemmQuantBase>; + friend Base; + + public: + using typename Base::AccDataType; + using typename Base::ADataType; + using typename Base::ALayout; + using typename Base::BDataType; + using typename Base::BLayout; + using typename Base::CDataType; + using typename Base::CLayout; + using typename Base::ComputeDataType; + using typename Base::QDataType; + + static constexpr auto QuantType = Base::QuantType; + static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize; + + protected: + void SetUpQuantTypeSpecific() {} + void TearDownQuantTypeSpecific() {} + + void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + const ck_tile::index_t stride_A = K; + const ck_tile::index_t stride_B = K; + const ck_tile::index_t stride_C = M; + + // RowColQuant uses per-row and per-column scales + const ck_tile::index_t stride_row_scales = 1; + const ck_tile::index_t stride_col_scales = 1; + + // Generate test data + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + ck_tile::HostTensor row_scales_m(ck_tile::host_tensor_descriptor( + M, 1, stride_row_scales, ck_tile::bool_constant{})); + ck_tile::HostTensor col_scales_n(ck_tile::host_tensor_descriptor( + N, 1, stride_col_scales, ck_tile::bool_constant{})); + + // Initialize data with random values + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(b_k_n); + ck_tile::FillUniformDistribution{0.001f, 0.01f}(row_scales_m); + ck_tile::FillUniformDistribution{0.001f, 0.01f}(col_scales_n); + + // Allocate device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); + ck_tile::DeviceMem 
row_scales_dev_buf(row_scales_m.get_element_space_size() * + sizeof(QDataType)); + ck_tile::DeviceMem col_scales_dev_buf(col_scales_n.get_element_space_size() * + sizeof(QDataType)); + ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType)); + + // Copy to device + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + row_scales_dev_buf.ToDevice(row_scales_m.data()); + col_scales_dev_buf.ToDevice(col_scales_n.data()); + + // Create args for kernel execution + ck_tile::QuantGemmHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr + row_scales_dev_buf.GetDeviceBuffer(), // aq_ptr (row scales) + col_scales_dev_buf.GetDeviceBuffer(), // bq_ptr (col scales) + 1, // k_batch + M, + N, + K, // M, N, K + 1, // QK_A (row scales) + 1, // QK_B (col scales) + stride_A, + stride_B, + stride_C, + stride_row_scales, + stride_col_scales // strides + }; + + // Run the kernel + ck_tile::stream_config stream_config{}; + this->invoke_quant_gemm(args, stream_config); + + // Validation using reference implementation + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Run reference RowColQuant implementation + ck_tile::reference_gemm_rowcol_quant( + a_m_k, row_scales_m, b_k_n, col_scales_n, c_m_n_host_ref); + + // Get device result + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); + + // Calculate error tolerances + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + this->template calculate_rtol_atol( + K, 1, max_accumulated_value); + + // Validate results + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect 
results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass) << "RowColQuant validation failed with M=" << M << ", N=" << N + << ", K=" << K; + + if(!pass) + { + std::cout << "RowColQuant - Relative error threshold: " + << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + } + + private: + // RowColQuant-specific pipeline implementation + template + void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args, + const ck_tile::stream_config& s) + { + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr bool transpose_c = false; + + using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem< + ADataType, + BDataType, + AccDataType, + AccDataType, + CodegenGemmShape, + CodegenGemmTraits, + transpose_c, + ComputeDataType, + ck_tile::GemmPipelineScheduler::Intrawave, + has_hot_loop_v, + tail_number_v>; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c, + 
ck_tile::memory_operation_enum::set>>; + + using Kernel = ck_tile::QuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Arguments not supported for RowColQuant kernel"); + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + }; + + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } +}; + +// TensorQuant-specific test fixture +template +class TestCkTileGemmTensorQuant + : public TestCkTileGemmQuantBase> +{ + using Base = TestCkTileGemmQuantBase>; + friend Base; + + public: + using typename Base::AccDataType; + using typename Base::ADataType; + using typename Base::ALayout; + using typename Base::BDataType; + using typename Base::BLayout; + using typename Base::CDataType; + using typename Base::CLayout; + using typename Base::ComputeDataType; + using typename Base::QDataType; + + static constexpr auto QuantType = Base::QuantType; + static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize; + + protected: + void SetUpQuantTypeSpecific() {} + void TearDownQuantTypeSpecific() {} + + void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K) + { + const ck_tile::index_t stride_A = K; + const ck_tile::index_t stride_B = K; + const ck_tile::index_t stride_C = M; + + // TensorQuant uses single scalar scale for each tensor + const ck_tile::index_t stride_scale_a = 1; + const ck_tile::index_t stride_scale_b = 1; + + // Generate test data + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{}))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{}))); + ck_tile::HostTensor scale_a( + ck_tile::host_tensor_descriptor(1, 1, stride_scale_a, ck_tile::bool_constant{})); + 
ck_tile::HostTensor scale_b( + ck_tile::host_tensor_descriptor(1, 1, stride_scale_b, ck_tile::bool_constant{})); + + // Initialize data with random values + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(a_m_k); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(b_k_n); + ck_tile::FillUniformDistribution{0.001f, 0.01f}(scale_a); + ck_tile::FillUniformDistribution{0.001f, 0.01f}(scale_b); + + // Allocate device memory + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType)); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType)); + ck_tile::DeviceMem scale_a_dev_buf(scale_a.get_element_space_size() * sizeof(QDataType)); + ck_tile::DeviceMem scale_b_dev_buf(scale_b.get_element_space_size() * sizeof(QDataType)); + ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType)); + + // Copy to device + a_m_k_dev_buf.ToDevice(a_m_k.data()); + b_k_n_dev_buf.ToDevice(b_k_n.data()); + scale_a_dev_buf.ToDevice(scale_a.data()); + scale_b_dev_buf.ToDevice(scale_b.data()); + + // Create args for kernel execution + ck_tile::QuantGemmHostArgs args{ + a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr + b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr + c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr + scale_a_dev_buf.GetDeviceBuffer(), // aq_ptr (scale A) + scale_b_dev_buf.GetDeviceBuffer(), // bq_ptr (scale B) + 1, // k_batch + M, + N, + K, // M, N, K + 1, // QK_A (tensor scale) + 1, // QK_B (tensor scale) + stride_A, + stride_B, + stride_C, + stride_scale_a, + stride_scale_b // strides + }; + + // Run the kernel + ck_tile::stream_config stream_config{}; + this->invoke_quant_gemm(args, stream_config); + + // Validation using reference implementation + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + // Run reference TensorQuant implementation + ck_tile::reference_gemm_tensor_quant( + a_m_k, scale_a, b_k_n, scale_b, c_m_n_host_ref); + + // Get 
device result + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{}))); + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data()); + + // Calculate error tolerances + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = + this->template calculate_rtol_atol( + K, 1, max_accumulated_value); + + // Validate results + bool pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + EXPECT_TRUE(pass) << "TensorQuant validation failed with M=" << M << ", N=" << N + << ", K=" << K; + + if(!pass) + { + std::cout << "TensorQuant - Relative error threshold: " + << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + } + } + + private: + // TensorQuant-specific pipeline implementation + template + void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args, + const ck_tile::stream_config& s) + { + using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; + + using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; + + const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr bool transpose_c = false; + + using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem< + ADataType, + BDataType, + AccDataType, + AccDataType, + CodegenGemmShape, + CodegenGemmTraits, + 
transpose_c, + ComputeDataType, + ck_tile::GemmPipelineScheduler::Intrawave, + has_hot_loop_v, + tail_number_v>; + + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + Base::M_Warp, + Base::N_Warp, + Base::M_Warp_Tile, + Base::N_Warp_Tile, + Base::K_Warp_Tile, + transpose_c, + ck_tile::memory_operation_enum::set>>; + + using Kernel = ck_tile::QuantGemmKernel; + + auto kargs = Kernel::MakeKernelArgs(args); + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + const dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw std::runtime_error("Arguments not supported for TensorQuant kernel"); + } + + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + }; + + return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); + } +}; diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp new file mode 100644 index 0000000000..1926b7cd0f --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp @@ -0,0 +1,64 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "ck_tile/host.hpp" +#include "ck_tile/ops/gemm.hpp" + +#include +#include + +#include "test_gemm_quant_fixtures.hpp" + +// Type aliases for readability +using RowMajor = ck_tile::tensor_layout::gemm::RowMajor; +using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor; +using FP8 = ck_tile::fp8_t; +using BF8 = ck_tile::bf8_t; +using Half = ck_tile::half_t; +using PkInt4 = ck_tile::pk_int4_t; +using AQuantGrouped = std::integral_constant; +using BQuantGrouped = std::integral_constant; +using RowColQuant = std::integral_constant; +using TensorQuant = std::integral_constant; +using GroupSize = std::integral_constant; + +// Type combinations for each quantization type +// clang-format off +using AQuantTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// clang-format off +using BQuantTypes = ::testing::Types< + std::tuple, + std::tuple, + std::tuple, + std::tuple +>; +// clang-format on + +// clang-format off +using RowColQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// clang-format off +using TensorQuantTypes = ::testing::Types< + std::tuple, + std::tuple +>; +// clang-format on + +// Test suites for each quantization type +TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes); +TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes); +TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes); +TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes); + +#include "test_gemm_quant_ut_cases.inc" diff --git a/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc b/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc new file mode 100644 index 0000000000..9b07afa2b3 --- /dev/null +++ b/test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc @@ -0,0 +1,28 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +// AQuant tests +TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} + +// BQuant tests +TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} + +// RowColQuant tests +TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} + +// TensorQuant tests +TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest) +{ + this->run_test_with_validation(1024, 1024, 1024); +} diff --git a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc b/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc deleted file mode 100644 index dbe652ac62..0000000000 --- a/test/ck_tile/gemm_block_scale/test_run_gemm_aquant_example.inc +++ /dev/null @@ -1,616 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include "ck_tile/core/config.hpp" -#include "ck_tile/host.hpp" -#include "test_gemm_aquant_utils.hpp" -#include "ck_tile/host/permute_pk_int4.hpp" - -template -float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s) -{ - constexpr bool kPadM = false; - constexpr bool kPadN = false; - constexpr bool kPadK = false; - - constexpr int kBlockPerCu = 1; - - static_assert(std::is_same_v); - - constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile; - constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile; - constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile; - - constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp; - constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp; - constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp; - - constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile; - constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile; - constexpr ck_tile::index_t 
K_Warp_Tile = GemmConfig::K_Warp_Tile; - - using CodegenGemmShape = - ck_tile::TileGemmShape, - ck_tile::sequence, - ck_tile::sequence>; - - using TilePartitioner = ck_tile::GemmTile1DPartitioner; - - using CodegenGemmTraits = ck_tile::TileGemmQuantTraits; - - using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase; - - using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3; - - const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile; - const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); - const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); - const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); - constexpr bool transposed_warp_gemm = false; - - const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) { - constexpr bool has_hot_loop_v = has_hot_loop_.value; - constexpr auto tail_number_v = tail_number_.value; - - using CodegenPipelineProblem = - ck_tile::GemmAQuantPipelineProblem; - using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3; - using GemmEpilogue = ck_tile::CShuffleEpilogue< - ck_tile::CShuffleEpilogueProblem, - AccDataType, - CDataType, - ck_tile::tuple<>, - CLayout, - ck_tile::element_wise::PassThrough, - TilePartitioner::MPerBlock, - TilePartitioner::NPerBlock, - M_Warp, - N_Warp, - M_Warp_Tile, - N_Warp_Tile, - K_Warp_Tile, - transposed_warp_gemm, - ck_tile::memory_operation_enum::set>>; - using Kernel = ck_tile::QuantGemmKernel; - - auto kargs = Kernel::MakeKernelArgs(args); - - const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); - const dim3 blocks = Kernel::BlockSize(); - - if(args.k_batch != 1) - { - throw std::runtime_error("split-k is not supported yet!"); - } - - if(!Kernel::IsSupportedArgument(kargs)) - { - throw std::runtime_error("Wrong! Arguments not supported! 
Skipping gemm!\n"); - } - - if(s.log_level_ > 0) - { - std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' - << "shape: " << CodegenGemmShape::GetName() << '\n' - << "problem: " << CodegenPipelineProblem::GetName() << '\n' - << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' - << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" - << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" - << std::endl; - } - - float ave_time = ck_tile::launch_kernel( - s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); - - return ave_time; - }; - return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num); -} - -template -static constexpr inline auto is_row_major(Layout layout_) -{ - return ck_tile::bool_constant, - ck_tile::tensor_layout::gemm::RowMajor>>{}; -} - -template -float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, - ck_tile::DeviceMem& aq_m_aqk_dev_buf, - ck_tile::DeviceMem& b_k_n_dev_buf, - ck_tile::DeviceMem& c_m_n_dev_buf, - ck_tile::index_t M, - ck_tile::index_t N, - ck_tile::index_t K, - ck_tile::index_t AQK, - ck_tile::index_t stride_A, - ck_tile::index_t stride_AQ, - ck_tile::index_t stride_B, - ck_tile::index_t stride_C, - ck_tile::index_t kbatch, - int n_warmup, - int n_repeat) -{ - ck_tile::QuantGemmHostArgs args; - args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer(); - args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer(); - args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer(); - args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer(); - args.k_batch = kbatch; - args.M = M; - args.N = N; - args.K = K; - args.QK_A = AQK; - args.stride_A = stride_A; - args.stride_B = stride_B; - args.stride_C = stride_C; - args.stride_AQ = stride_AQ; - - float ave_time = gemm_calc_aquant( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); - - std::size_t flop = std::size_t(2) * M * N * K; - std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK + - sizeof(BDataType) * N * K + 
sizeof(CDataType) * M * N; - float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_byte / 1.E6 / ave_time; - - std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K - << " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B - << " StrideC =" << stride_C << " A_Layout =" << ALayout::name - << " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name - << " A_Type = " << DataTypeTraits::name - << " AQ_Type = " << DataTypeTraits::name - << " B_Type = " << DataTypeTraits::name - << " Acc_Type = " << DataTypeTraits::name - << " C_Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " - << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; - - return ave_time; -} - -template -bool run_gemm_test_with_layouts(int argc, - char* argv[], - const ALayout a_layout = ALayout{}, - const AQLayout aq_layout = AQLayout{}, - const BLayout b_layout = BLayout{}, - [[maybe_unused]] const CLayout c_layout = CLayout{}) -{ - auto [result, arg_parser] = create_args(argc, argv); - if(!result) - return false; - - using ADataType = typename TypeConfig::ADataType; - using AQDataType = typename TypeConfig::QDataType; - using BDataType = typename TypeConfig::BDataType; - using AccDataType = typename TypeConfig::AccDataType; - using CDataType = typename TypeConfig::CDataType; - - ck_tile::index_t M = arg_parser.get_int("m"); - ck_tile::index_t N = arg_parser.get_int("n"); - ck_tile::index_t K = arg_parser.get_int("k"); - - if(K % QuantGroupSize != 0) - { - throw std::runtime_error("K must be aligned with QuantGroupSize"); - } - - ck_tile::index_t AQK = K / QuantGroupSize; - - ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); - ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q"); - ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); - ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); - - ck_tile::index_t kbatch = arg_parser.get_int("split_k"); - int n_warmup = 
arg_parser.get_int("warmup"); - int n_repeat = arg_parser.get_int("repeat"); - ck_tile::index_t init_method = arg_parser.get_int("init"); - - stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); - stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout)); - stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); - stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); - - ck_tile::HostTensor a_m_k( - ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); - ck_tile::HostTensor aq_m_aqk( - ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout))); - ck_tile::HostTensor b_k_n( - ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); - ck_tile::HostTensor c_m_n_dev_result( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution fill_seed(0, 500); - - if(init_method == 0) - { - if constexpr(std::is_same_v) - { - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}( - a_m_k); - } - else - { - ck_tile::FillUniformDistribution{-2.0f, 3.0f, fill_seed(gen)}(a_m_k); - } - ck_tile::FillUniformDistribution{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk); - ck_tile::FillUniformDistribution{-5.0f, 5.0f, fill_seed(gen)}(b_k_n); - } - else if(init_method == 1) - { - std::cout << "Monotonic initialization is not supported." 
<< std::endl; - return true; - } - else if(init_method == 2) - { - ck_tile::FillConstant{static_cast(0x22)}(a_m_k); - ck_tile::FillConstant{static_cast(0.5f)}(aq_m_aqk); - ck_tile::FillConstant{static_cast(0x38)}(b_k_n); - } - else - { - a_m_k.SetZero(); - aq_m_aqk.SetZero(); - b_k_n.SetZero(); - } - - ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); - ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); - ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - - if constexpr(std::is_same_v) - { - // Permute vector pk_i4x4 data for device implementation - ck_tile::HostTensor a_m_k_dev = a_m_k; - ck_tile::permute_vectors_i4x4_b(a_m_k_dev); - a_m_k_dev_buf.ToDevice(a_m_k_dev.data()); - } - else - { - a_m_k_dev_buf.ToDevice(a_m_k.data()); - } - aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data()); - b_k_n_dev_buf.ToDevice(b_k_n.data()); - c_m_n_dev_buf.SetZero(); - c_m_n_dev_result.SetZero(); - - invoke_gemm(a_m_k_dev_buf, - aq_m_aqk_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - AQK, - stride_A, - stride_AQ, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); - - c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); - bool pass = true; - - if(arg_parser.get_int("v") == 1) - { - ck_tile::HostTensor c_m_n_host_ref( - ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); - c_m_n_host_ref.SetZero(); - - ck_tile::reference_gemm_quant(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref); - const float max_accumulated_value = - *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol( - K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, - c_m_n_host_ref, - "Error: Incorrect results!", - rtol_atol.at(ck_tile::number<0>{}), - rtol_atol.at(ck_tile::number<1>{})); - - if(!pass) - { - std::cout << "Relative 
error threshold: " << rtol_atol.at(ck_tile::number<0>{}) - << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) - << std::endl; - } - std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl; - } - else if(arg_parser.get_int("v") == 2) - { - std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl; - return false; - } - - return pass; -} - -template -bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) -{ - using Row = ck_tile::tensor_layout::gemm::RowMajor; - using Col = ck_tile::tensor_layout::gemm::ColumnMajor; - - if constexpr(std::is_same_v || - std::is_same_v || - std::is_same_v) - { - if(a_layout == "R" && b_layout == "C") - { - return run_gemm_test_with_layouts( - argc, argv, Row{}, Row{}, Col{}, Row{}); - } - else - { - throw std::runtime_error("Unsupported memory layout for the input matrices!"); - } - } - else - { - throw std::runtime_error("Unsupported data type for A."); - } - - return true; -} - -template