[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)

* rename gemm_group_quant to gemm_quant

* Add TensorWise quant mode

* Cshuffle epilogue tests with tensor scaling

* Add tensor quant to example

* Don't use readfirstlane for reading scales - doesn't work for some reason

* Add to changelog

* revert include - from a merge problem?

* revert common.hpp include

* revert host.hpp include

* remove unused utility function

* rename quant pipeline problem

* refactor quant tests

* remove aquant utils

* use TEST_F

* fix all tests by changing gemm config

* Use typed tests

* fix copyright

[ROCm/composable_kernel commit: 4363a82bd6]
This commit is contained in:
Sami Remes
2025-09-20 02:52:35 +03:00
committed by GitHub
parent ee43f0f0be
commit 8d2a444c55
39 changed files with 1555 additions and 1056 deletions

View File

@@ -41,8 +41,8 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTest)
NPerXdl,
KPerXdl>;
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>();
EXPECT_TRUE(result) << "Basic CShuffleEpilogue test failed";
auto result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::None);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed";
}
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
@@ -73,8 +73,45 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
NPerXdl,
KPerXdl>;
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(true);
EXPECT_TRUE(result) << "Scale CShuffleEpilogue test failed";
auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::RowCol);
EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2";
EXPECT_FLOAT_EQ(result[1], 4.0F)
<< "RowCol CShuffleEpilogue test failed: second element not 2*2";
}
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale)
{
// Basic test configuration with half_t data types
using ADataType = ck_tile::half_t;
using BDataType = ck_tile::half_t;
using AccDataType = float;
using ODataType = ck_tile::half_t;
constexpr index_t kMPerBlock = 256;
constexpr index_t kNPerBlock = 256;
constexpr index_t MWave = 2;
constexpr index_t NWave = 2;
constexpr index_t MPerXdl = 32;
constexpr index_t NPerXdl = 32;
constexpr index_t KPerXdl = 8;
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
BDataType,
AccDataType,
ODataType,
kMPerBlock,
kNPerBlock,
MWave,
NWave,
MPerXdl,
NPerXdl,
KPerXdl>;
auto result =
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::Tensor);
EXPECT_FLOAT_EQ(result[0], 4.0F)
<< "TensorScale CShuffleEpilogue test failed: first element not 2*2=4";
}
int main(int argc, char** argv)

View File

@@ -19,8 +19,15 @@
namespace ck_tile {
enum class ScaleType
{
None,
RowCol,
Tensor
};
// Simple test kernel to invoke the CShuffleEpilogue
template <typename Problem, index_t M, index_t N, bool UseScale>
template <typename Problem, index_t M, index_t N, ScaleType Scale>
__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data,
float* m_scale,
float* n_scale)
@@ -61,7 +68,7 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
auto empty_ds = make_tuple();
// Call the epilogue
if constexpr(UseScale)
if constexpr(Scale == ScaleType::RowCol)
{
const auto m_scale_window = make_tile_window(
make_naive_tensor_view<address_space_enum::global>(
@@ -75,6 +82,10 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
{0, 0});
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, m_scale_window, n_scale_window);
}
else if constexpr(Scale == ScaleType::Tensor)
{
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, *m_scale, *n_scale);
}
else
{
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem);
@@ -113,7 +124,7 @@ using SimpleCShuffleEpilogueProblem =
memory_operation_enum::set>;
template <typename Problem, index_t M, index_t N>
bool run_cshuffle_epilogue_test(bool use_scale = false)
auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)
{
using ODataType = typename Problem::ODataType;
@@ -142,7 +153,7 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
dim3 gridSize(1, 1, 1);
dim3 blockSize(kBlockSize, 1, 1);
if(use_scale)
if(scale == ScaleType::RowCol)
{
float* m_scale;
float* n_scale;
@@ -155,12 +166,25 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(
hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice));
test_cshuffle_epilogue_kernel<Problem, M, N, true>
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::RowCol>
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
}
else if(scale == ScaleType::Tensor)
{
float* m_scale;
float* n_scale;
std::vector<float> h_m_scale(1, 2.0F);
std::vector<float> h_n_scale(1, 1.0F);
HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float)));
HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float)));
HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice));
HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice));
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::Tensor>
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
}
else
{
test_cshuffle_epilogue_kernel<Problem, M, N, false>
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::None>
<<<gridSize, blockSize>>>(device_output, nullptr, nullptr);
}
@@ -172,20 +196,10 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
HIP_CHECK_ERROR(hipMemcpy(
host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost));
// Basic verification - just check that output has a 2, and 4 if using scaling
bool has_2 =
type_convert<float>(host_output[0]) > 1.9F && type_convert<float>(host_output[0]) < 2.1F;
bool scale_has_4 = true;
if(use_scale)
{
scale_has_4 = type_convert<float>(host_output[1]) > 3.9F &&
type_convert<float>(host_output[1]) < 4.1F;
}
// Cleanup
HIP_CHECK_ERROR(hipFree(device_output));
return has_2 && scale_has_4;
return host_output;
}
} // namespace ck_tile

View File

@@ -6,14 +6,9 @@ endif()
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
set(TEST_GEMM_NAME test_tile_gemm_aquant_basic)
set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8)
foreach(QUANT_TYPE ${QUANT_TYPES})
add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp)
target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
endforeach()
# Typed Test Suite for GEMM Quantization
add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp)
target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
else()
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
endif()

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("bf8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("fp8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4bf8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4f32bf8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4f32fp8"); }

View File

@@ -1,6 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "test_run_gemm_aquant_example.inc"
int main() { return run_gemm_combinations("i4fp8"); }

View File

@@ -1,243 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_group_quant.hpp"
#define CK_TILE_PIPELINE_PREFILL 1
#define CK_TILE_PIPELINE_DECODE 2
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
constexpr ck_tile::index_t get_k_warp_tile()
{
#if defined(CK_GFX950_SUPPORT)
constexpr bool is_8bit_float =
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
if constexpr(M_Warp_Tile == 32)
return is_8bit_float ? 64 : 16;
else
return is_8bit_float ? 128 : 32;
#else
if constexpr(M_Warp_Tile == 32)
return 16;
else
return 32;
#endif
}
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
class ArgumentsNotSupportedException : public std::logic_error
{
public:
explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {}
};
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = true;
};
template <typename PrecType>
struct GemmConfigDecode : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
};
template <typename PrecType>
struct GemmConfigPrefill : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 128;
static constexpr ck_tile::index_t N_Tile = 128;
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
static constexpr int kBlockPerCu = 2;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
};
template <typename PrecType>
struct GemmConfigPreshuffleQuant : public GemmConfigBase
{
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile =
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
static constexpr bool PreshuffleQuant = true;
};
template <typename ADataType_,
typename BDataType_ = ADataType_,
typename CDataType_ = ADataType_,
typename QDataType_ = float>
struct GemmQuantTypeConfig
{
using ADataType = ADataType_;
using QDataType = QDataType_;
using BDataType = BDataType_;
using AccDataType = float;
using CDataType = CDataType_;
};
template <typename T>
struct DataTypeTraits;
template <>
struct DataTypeTraits<float>
{
static constexpr const char* name = "fp32";
};
template <>
struct DataTypeTraits<double>
{
static constexpr const char* name = "fp64";
};
template <>
struct DataTypeTraits<int32_t>
{
static constexpr const char* name = "int32";
};
template <>
struct DataTypeTraits<ck_tile::half_t>
{
static constexpr const char* name = "fp16";
};
template <>
struct DataTypeTraits<ck_tile::bf16_t>
{
static constexpr const char* name = "bf16";
};
template <>
struct DataTypeTraits<ck_tile::fp8_t>
{
static constexpr const char* name = "fp8";
};
template <>
struct DataTypeTraits<ck_tile::bf8_t>
{
static constexpr const char* name = "bf8";
};
template <>
struct DataTypeTraits<ck_tile::pk_int4_t>
{
static constexpr const char* name = "pk_int4_t";
};
template <>
struct DataTypeTraits<ck_tile::int8_t>
{
static constexpr const char* name = "int8";
};
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3840", "m dimension")
.insert("n", "4096", "n dimension")
.insert("k", "2048", "k dimension")
.insert("a_layout", "R", "A tensor data layout - Row by default")
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
.insert("b_layout", "C", "B tensor data layout - Column by default")
.insert("c_layout", "R", "C tensor data layout - Row by default")
.insert("stride_a", "0", "Tensor A stride")
.insert("stride_q", "0", "Tensor AQ stride")
.insert("stride_b", "0", "Tensor B stride")
.insert("stride_c", "0", "Tensor C stride")
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
.insert("split_k", "1", "splitK value")
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
.insert("persistent", "0", "0:non-persistent, 1:persistent")
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
// host API
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);

View File

@@ -0,0 +1,179 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <tuple>
#include <stdexcept>
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/check_err.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm_quant.hpp"
// Forward declarations for quant type-specific implementations
template <ck_tile::QuantType QT>
struct QuantTypeTraits;
// Base class for common quant gemm functionality
template <typename Tuple, typename Derived>
class TestCkTileGemmQuantBase : public ::testing::Test
{
protected:
using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>;
using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>;
using QDataType = std::tuple_element_t<5, Tuple>;
using CDataType = std::tuple_element_t<6, Tuple>;
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
using GemmConfig = std::tuple_element_t<8, Tuple>;
static constexpr uint32_t QuantGroupSize = std::tuple_element_t<9, Tuple>::value;
using AccDataType = float; // accumulate always in float
// Get the quant-type specific data types from traits
using QuantTraits = QuantTypeTraits<QuantType>;
using ComputeDataType = typename QuantTraits::template ComputeDataType<ADataType, BDataType>;
static constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
static constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
static constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
static constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
static constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
static constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
public:
void SetUp() override { static_cast<Derived*>(this)->SetUpQuantTypeSpecific(); }
void TearDown() override { static_cast<Derived*>(this)->TearDownQuantTypeSpecific(); }
// Common test execution logic
void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr bool kPreshuffle = false;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
kPreshuffle,
ALayout,
BLayout,
CLayout,
QuantType>;
// Let the derived class create the appropriate pipeline and epilogue
static_cast<Derived*>(this)
->template run_quant_gemm_impl<CodegenGemmShape, TilePartitioner, CodegenGemmTraits>(
args, s);
}
void RunTest(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
// Generate test data and run the kernel
static_cast<Derived*>(this)->run_test_with_validation(M, N, K);
}
// Helper function to check layout
template <typename Layout>
static constexpr auto is_row_major(Layout)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(Layout{})>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
// Tolerance calculation function for validation
template <typename ADataType_, typename BDataType_, typename AccDataType_, typename CDataType_>
auto calculate_rtol_atol(const ck_tile::index_t K,
const ck_tile::index_t kbatch,
const float max_accumulated_value)
{
using ComputeType =
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>;
// Calculate thresholds
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType_, AccDataType_>(
ck_tile::integer_divide_ceil(K, kbatch));
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType_, AccDataType_>(
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
// Calculate error due to split_k accumulation
const auto rtol_split_k =
ck_tile::get_relative_threshold<CDataType_, CDataType_, CDataType_>(kbatch);
const auto atol_split_k =
ck_tile::get_absolute_threshold<CDataType_, CDataType_, CDataType_>(
max_accumulated_value, kbatch);
// Use higher threshold
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
}
};
// Define generic QuantTypeTraits template (will be specialized)
template <ck_tile::QuantType QT>
struct QuantTypeTraits
{
static_assert(QT == ck_tile::QuantType::AQuantGrouped ||
QT == ck_tile::QuantType::BQuantGrouped ||
QT == ck_tile::QuantType::RowColQuant ||
QT == ck_tile::QuantType::TensorQuant,
"Unsupported quantization type");
};
// Specialization for AQuantGrouped
template <>
struct QuantTypeTraits<ck_tile::QuantType::AQuantGrouped>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = BDataType; // For AQuant, compute type is BDataType
static constexpr const char* name = "aquant";
};
// Specialization for BQuantGrouped
template <>
struct QuantTypeTraits<ck_tile::QuantType::BQuantGrouped>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = ADataType; // For BQuant, compute type is ADataType
static constexpr const char* name = "bquant";
};
// Specialization for RowColQuant
template <>
struct QuantTypeTraits<ck_tile::QuantType::RowColQuant>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = ADataType; // For RowColQuant, compute type is ADataType
static constexpr const char* name = "rowcol";
};
// Specialization for TensorQuant
template <>
struct QuantTypeTraits<ck_tile::QuantType::TensorQuant>
{
template <typename ADataType, typename BDataType>
using ComputeDataType = ADataType; // For TensorQuant, compute type is ADataType
static constexpr const char* name = "tensor";
};

View File

@@ -0,0 +1,919 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "test_gemm_quant_base.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
struct GemmConfigBase
{
static constexpr bool kPadM = false;
static constexpr bool kPadN = false;
static constexpr bool kPadK = false;
static constexpr bool PermuteA = false;
static constexpr bool PermuteB = false;
static constexpr bool TransposeC = false;
static constexpr bool UseStructuredSparsity = false;
static constexpr int kBlockPerCu = 1;
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
static constexpr ck_tile::index_t TileParitionerM01 = 4;
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
static constexpr ck_tile::index_t NumWaveGroups = 1;
static constexpr bool PreshuffleQuant = false;
static constexpr bool DoubleSmemBuffer = false;
// Default GEMM tile sizes for tests
static constexpr ck_tile::index_t M_Tile = 16;
static constexpr ck_tile::index_t N_Tile = 64;
static constexpr ck_tile::index_t K_Tile = 256;
static constexpr ck_tile::index_t M_Warp = 1;
static constexpr ck_tile::index_t N_Warp = 4;
static constexpr ck_tile::index_t K_Warp = 1;
static constexpr ck_tile::index_t M_Warp_Tile = 16;
static constexpr ck_tile::index_t N_Warp_Tile = 16;
static constexpr ck_tile::index_t K_Warp_Tile = 32;
};
template <typename Tuple>
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
// AQuant-specific data generation
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// AQuant uses grouped quantization for A matrix
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
const ck_tile::index_t stride_AQ =
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{}));
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<QDataType> aq_m_aqk(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
// Initialize data with random values
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f}(a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f}(a_m_k);
}
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(aq_m_aqk);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> temp = a_m_k;
ck_tile::permute_vectors_i4x4_b(temp);
a_m_k_dev_buf.ToDevice(temp.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
nullptr, // bq_ptr (not used for AQuant)
1, // k_batch
M,
N,
K, // M, N, K
AQK, // QK_A
0, // QK_B (not used for AQuant)
stride_A,
stride_B,
stride_C,
stride_AQ,
0 // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference AQuant implementation
ck_tile::reference_gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "AQuantGrouped - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// AQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using PipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
QDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for AQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};
// BQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// BQuant uses grouped quantization for B matrix
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
const ck_tile::index_t stride_BQ = BQK;
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> bq_bqk_n(
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, this->is_row_major(BLayout{})));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_n);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<BDataType> temp = b_k_n;
ck_tile::permute_vectors_i4x4_b(temp);
b_k_n_dev_buf.ToDevice(temp.data());
}
else
{
b_k_n_dev_buf.ToDevice(b_k_n.data());
}
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
nullptr, // aq_ptr (not used for BQuant)
bq_bqk_n_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
1, // k_batch
M,
N,
K, // M, N, K
0, // QK_A (not used for BQuant)
BQK, // QK_B
stride_A,
stride_B,
stride_C,
0,
stride_BQ // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference BQuant implementation
ck_tile::reference_gemm_quant<ADataType,
QDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "BQuantGrouped - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// BQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using PipelineProblem =
ck_tile::GemmBQuantPipelineProblem<ADataType,
BDataType,
QDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
false, // transpose_c
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::BQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for BQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};
// RowColQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmRowColQuant
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// RowColQuant uses per-row and per-column scales
const ck_tile::index_t stride_row_scales = 1;
const ck_tile::index_t stride_col_scales = 1;
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> row_scales_m(ck_tile::host_tensor_descriptor(
M, 1, stride_row_scales, ck_tile::bool_constant<true>{}));
ck_tile::HostTensor<QDataType> col_scales_n(ck_tile::host_tensor_descriptor(
N, 1, stride_col_scales, ck_tile::bool_constant<true>{}));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(row_scales_m);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(col_scales_n);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem row_scales_dev_buf(row_scales_m.get_element_space_size() *
sizeof(QDataType));
ck_tile::DeviceMem col_scales_dev_buf(col_scales_n.get_element_space_size() *
sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
row_scales_dev_buf.ToDevice(row_scales_m.data());
col_scales_dev_buf.ToDevice(col_scales_n.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
row_scales_dev_buf.GetDeviceBuffer(), // aq_ptr (row scales)
col_scales_dev_buf.GetDeviceBuffer(), // bq_ptr (col scales)
1, // k_batch
M,
N,
K, // M, N, K
1, // QK_A (row scales)
1, // QK_B (col scales)
stride_A,
stride_B,
stride_C,
stride_row_scales,
stride_col_scales // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference RowColQuant implementation
ck_tile::reference_gemm_rowcol_quant<ADataType,
QDataType,
BDataType,
QDataType,
AccDataType,
CDataType>(
a_m_k, row_scales_m, b_k_n, col_scales_n, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "RowColQuant validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "RowColQuant - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// RowColQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
ADataType,
BDataType,
AccDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::RowColQuant>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for RowColQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};
// TensorQuant-specific test fixture
template <typename Tuple>
class TestCkTileGemmTensorQuant
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>
{
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>;
friend Base;
public:
using typename Base::AccDataType;
using typename Base::ADataType;
using typename Base::ALayout;
using typename Base::BDataType;
using typename Base::BLayout;
using typename Base::CDataType;
using typename Base::CLayout;
using typename Base::ComputeDataType;
using typename Base::QDataType;
static constexpr auto QuantType = Base::QuantType;
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
protected:
void SetUpQuantTypeSpecific() {}
void TearDownQuantTypeSpecific() {}
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
{
const ck_tile::index_t stride_A = K;
const ck_tile::index_t stride_B = K;
const ck_tile::index_t stride_C = M;
// TensorQuant uses single scalar scale for each tensor
const ck_tile::index_t stride_scale_a = 1;
const ck_tile::index_t stride_scale_b = 1;
// Generate test data
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
ck_tile::HostTensor<QDataType> scale_a(
ck_tile::host_tensor_descriptor(1, 1, stride_scale_a, ck_tile::bool_constant<true>{}));
ck_tile::HostTensor<QDataType> scale_b(
ck_tile::host_tensor_descriptor(1, 1, stride_scale_b, ck_tile::bool_constant<true>{}));
// Initialize data with random values
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_a);
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_b);
// Allocate device memory
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
ck_tile::DeviceMem scale_a_dev_buf(scale_a.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem scale_b_dev_buf(scale_b.get_element_space_size() * sizeof(QDataType));
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
// Copy to device
a_m_k_dev_buf.ToDevice(a_m_k.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
scale_a_dev_buf.ToDevice(scale_a.data());
scale_b_dev_buf.ToDevice(scale_b.data());
// Create args for kernel execution
ck_tile::QuantGemmHostArgs args{
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
scale_a_dev_buf.GetDeviceBuffer(), // aq_ptr (scale A)
scale_b_dev_buf.GetDeviceBuffer(), // bq_ptr (scale B)
1, // k_batch
M,
N,
K, // M, N, K
1, // QK_A (tensor scale)
1, // QK_B (tensor scale)
stride_A,
stride_B,
stride_C,
stride_scale_a,
stride_scale_b // strides
};
// Run the kernel
ck_tile::stream_config stream_config{};
this->invoke_quant_gemm(args, stream_config);
// Validation using reference implementation
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
// Run reference TensorQuant implementation
ck_tile::reference_gemm_tensor_quant<ADataType,
QDataType,
BDataType,
QDataType,
AccDataType,
CDataType>(
a_m_k, scale_a, b_k_n, scale_b, c_m_n_host_ref);
// Get device result
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
// Calculate error tolerances
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol =
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, 1, max_accumulated_value);
// Validate results
bool pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
EXPECT_TRUE(pass) << "TensorQuant validation failed with M=" << M << ", N=" << N
<< ", K=" << K;
if(!pass)
{
std::cout << "TensorQuant - Relative error threshold: "
<< rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
}
private:
// TensorQuant-specific pipeline implementation
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
const ck_tile::stream_config& s)
{
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr bool transpose_c = false;
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
ADataType,
BDataType,
AccDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
transpose_c,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
Base::M_Warp,
Base::N_Warp,
Base::M_Warp_Tile,
Base::N_Warp_Tile,
Base::K_Warp_Tile,
transpose_c,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
GemmPipeline,
GemmEpilogue,
ck_tile::QuantType::TensorQuant>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Arguments not supported for TensorQuant kernel");
}
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
Kernel{}, grids, blocks, 0, kargs));
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
};

View File

@@ -0,0 +1,64 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include <gtest/gtest.h>
#include <memory>
#include "test_gemm_quant_fixtures.hpp"
// Type aliases for readability
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
using FP8 = ck_tile::fp8_t;
using BF8 = ck_tile::bf8_t;
using Half = ck_tile::half_t;
using PkInt4 = ck_tile::pk_int4_t;
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
using GroupSize = std::integral_constant<unsigned int, 128>;
// Type combinations for each quantization type
// clang-format off
using AQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// clang-format off
using BQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
>;
// clang-format on
// clang-format off
using RowColQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, RowColQuant, GemmConfigBase, GroupSize>
>;
// clang-format on
// clang-format off
using TensorQuantTypes = ::testing::Types<
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, TensorQuant, GemmConfigBase, GroupSize>,
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, TensorQuant, GemmConfigBase, GroupSize>
>;
// clang-format on
// Test suites for each quantization type
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
#include "test_gemm_quant_ut_cases.inc"

View File

@@ -0,0 +1,28 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// AQuant tests
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// BQuant tests
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// RowColQuant tests
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}
// TensorQuant tests
TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest)
{
this->run_test_with_validation(1024, 1024, 1024);
}

View File

@@ -1,616 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstring>
#include <iostream>
#include <ostream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <random>
#include "ck_tile/core/config.hpp"
#include "ck_tile/host.hpp"
#include "test_gemm_aquant_utils.hpp"
#include "ck_tile/host/permute_pk_int4.hpp"
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ComputeDataType,
typename ALayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
{
constexpr bool kPadM = false;
constexpr bool kPadN = false;
constexpr bool kPadK = false;
constexpr int kBlockPerCu = 1;
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
using CodegenGemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
kPadN,
kPadK,
false, // preshuffle
ALayout,
BLayout,
CLayout,
ck_tile::QuantType::AQuantGrouped>;
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
ComputeDataType>;
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
constexpr bool transposed_warp_gemm = false;
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using CodegenPipelineProblem =
ck_tile::GemmAQuantPipelineProblem<ADataType,
AQDataType,
BDataType,
AccDataType,
CodegenGemmShape,
CodegenGemmTraits,
QuantGroupSize,
transposed_warp_gemm,
ComputeDataType,
ck_tile::GemmPipelineScheduler::Intrawave,
has_hot_loop_v,
tail_number_v>;
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
ck_tile::tuple<>,
AccDataType,
CDataType,
ck_tile::tuple<>,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
transposed_warp_gemm,
ck_tile::memory_operation_enum::set>>;
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
CodegenGemmPipeline,
GemmEpilogue,
ck_tile::QuantType::AQuantGrouped>;
auto kargs = Kernel::MakeKernelArgs(args);
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
const dim3 blocks = Kernel::BlockSize();
if(args.k_batch != 1)
{
throw std::runtime_error("split-k is not supported yet!");
}
if(!Kernel::IsSupportedArgument(kargs))
{
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
}
if(s.log_level_ > 0)
{
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
<< "shape: " << CodegenGemmShape::GetName() << '\n'
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
<< std::endl;
}
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
return ave_time;
};
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
}
template <typename Layout>
static constexpr inline auto is_row_major(Layout layout_)
{
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
ck_tile::tensor_layout::gemm::RowMajor>>{};
}
template <typename GemmConfig,
typename ADataType,
typename AQDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout,
uint32_t QuantGroupSize>
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
ck_tile::DeviceMem& b_k_n_dev_buf,
ck_tile::DeviceMem& c_m_n_dev_buf,
ck_tile::index_t M,
ck_tile::index_t N,
ck_tile::index_t K,
ck_tile::index_t AQK,
ck_tile::index_t stride_A,
ck_tile::index_t stride_AQ,
ck_tile::index_t stride_B,
ck_tile::index_t stride_C,
ck_tile::index_t kbatch,
int n_warmup,
int n_repeat)
{
ck_tile::QuantGemmHostArgs args;
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
args.k_batch = kbatch;
args.M = M;
args.N = N;
args.K = K;
args.QK_A = AQK;
args.stride_A = stride_A;
args.stride_B = stride_B;
args.stride_C = stride_C;
args.stride_AQ = stride_AQ;
float ave_time = gemm_calc_aquant<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
BDataType,
ALayout,
BLayout,
CLayout,
QuantGroupSize>(
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
std::size_t flop = std::size_t(2) * M * N * K;
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
<< " A_Type = " << DataTypeTraits<ADataType>::name
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
<< " B_Type = " << DataTypeTraits<BDataType>::name
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
return ave_time;
}
template <typename GemmConfig,
typename TypeConfig,
uint32_t QuantGroupSize,
typename ALayout,
typename AQLayout,
typename BLayout,
typename CLayout>
bool run_gemm_test_with_layouts(int argc,
char* argv[],
const ALayout a_layout = ALayout{},
const AQLayout aq_layout = AQLayout{},
const BLayout b_layout = BLayout{},
[[maybe_unused]] const CLayout c_layout = CLayout{})
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
using ADataType = typename TypeConfig::ADataType;
using AQDataType = typename TypeConfig::QDataType;
using BDataType = typename TypeConfig::BDataType;
using AccDataType = typename TypeConfig::AccDataType;
using CDataType = typename TypeConfig::CDataType;
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t K = arg_parser.get_int("k");
if(K % QuantGroupSize != 0)
{
throw std::runtime_error("K must be aligned with QuantGroupSize");
}
ck_tile::index_t AQK = K / QuantGroupSize;
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
int n_warmup = arg_parser.get_int("warmup");
int n_repeat = arg_parser.get_int("repeat");
ck_tile::index_t init_method = arg_parser.get_int("init");
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
ck_tile::HostTensor<ADataType> a_m_k(
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
ck_tile::HostTensor<AQDataType> aq_m_aqk(
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
ck_tile::HostTensor<BDataType> b_k_n(
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
if(init_method == 0)
{
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
a_m_k);
}
else
{
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
}
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
}
else if(init_method == 1)
{
std::cout << "Monotonic initialization is not supported." << std::endl;
return true;
}
else if(init_method == 2)
{
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
}
else
{
a_m_k.SetZero();
aq_m_aqk.SetZero();
b_k_n.SetZero();
}
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
{
// Permute vector pk_i4x4 data for device implementation
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
}
else
{
a_m_k_dev_buf.ToDevice(a_m_k.data());
}
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
b_k_n_dev_buf.ToDevice(b_k_n.data());
c_m_n_dev_buf.SetZero();
c_m_n_dev_result.SetZero();
invoke_gemm<GemmConfig,
ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
ALayout,
AQLayout,
BLayout,
CLayout,
QuantGroupSize>(a_m_k_dev_buf,
aq_m_aqk_dev_buf,
b_k_n_dev_buf,
c_m_n_dev_buf,
M,
N,
K,
AQK,
stride_A,
stride_AQ,
stride_B,
stride_C,
kbatch,
n_warmup,
n_repeat);
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
bool pass = true;
if(arg_parser.get_int("v") == 1)
{
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
c_m_n_host_ref.SetZero();
ck_tile::reference_gemm_quant<ADataType,
AQDataType,
BDataType,
AccDataType,
CDataType,
QuantGroupSize,
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
const float max_accumulated_value =
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
K, kbatch, max_accumulated_value);
pass = ck_tile::check_err(c_m_n_dev_result,
c_m_n_host_ref,
"Error: Incorrect results!",
rtol_atol.at(ck_tile::number<0>{}),
rtol_atol.at(ck_tile::number<1>{}));
if(!pass)
{
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
<< std::endl;
}
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
}
else if(arg_parser.get_int("v") == 2)
{
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
return false;
}
return pass;
}
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
{
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
{
if(a_layout == "R" && b_layout == "C")
{
return run_gemm_test_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
argc, argv, Row{}, Row{}, Col{}, Row{});
}
else
{
throw std::runtime_error("Unsupported memory layout for the input matrices!");
}
}
else
{
throw std::runtime_error("Unsupported data type for A.");
}
return true;
}
template <template <typename PreType> typename GemmConfig>
bool run_gemm_test(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
std::string data_type = arg_parser.get_str("prec");
std::string a_layout = arg_parser.get_str("a_layout");
std::string b_layout = arg_parser.get_str("b_layout");
if(data_type == "fp8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "bf8")
{
using TypeConfig =
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
ck_tile::fp8_t>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
ck_tile::bf8_t>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32fp8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::fp8_t,
ck_tile::half_t,
float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else if(data_type == "i4f32bf8")
{
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
ck_tile::bf8_t,
ck_tile::half_t,
float>{});
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
a_layout, b_layout, argc, argv);
}
else
{
throw std::runtime_error("Unsupported data type for this operation !!!");
}
}
int run_gemm_combinations(std::string const& data_type)
{
// Define possible values for each parameter
std::vector<std::vector<std::string>> mnk_values = {{
"1",
"2048",
"5120",
},
{
"2",
"2048",
"5120",
},
{
"16",
"2048",
"5120",
},
{
"17",
"2048",
"5120",
},
{
"2047",
"5120",
"1024",
},
{
"2048",
"5120",
"1024",
}};
std::vector<std::string> prec_values = {data_type};
// We'll store all our arguments as strings first
std::vector<std::string> arg_strings = {"test_tile_gemm_aquant_basic",
"", // m placeholder
"", // n placeholder
"", // k placeholder
"", // prec placeholder
"-init=0",
"-v=1",
"-warmup=0",
"-repeat=1"};
// Create an array of const char pointers for argv
constexpr size_t ARG_COUNT = 9;
constexpr size_t ARG_MAX_LEN = 64;
char args[ARG_COUNT][ARG_MAX_LEN];
char* argv[ARG_COUNT];
// Run all combinations
bool is_success = true;
for(const auto& mnk : mnk_values)
{
arg_strings[1] = "-m=" + mnk[0];
arg_strings[2] = "-n=" + mnk[1];
arg_strings[3] = "-k=" + mnk[2];
for(const auto& prec : prec_values)
{
arg_strings[4] = "-prec=" + prec;
// Set up the argv array with pointers to the string data
for(size_t i = 0; i < ARG_COUNT; i++)
{
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
argv[i] = args[i];
}
std::cout << "Arguments received: ";
for(size_t i = 1; i < ARG_COUNT; ++i)
{
std::cout << argv[i] << " ";
}
std::cout << std::endl;
// Call the function with the current configuration
try
{
is_success = run_gemm_test<GemmConfigDecode>(ARG_COUNT, argv) && is_success;
}
catch(const ArgumentsNotSupportedException& e)
{
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
// ArgumentsNotSupportedException is not an error. Do not change is_success
}
catch(const std::runtime_error& e)
{
std::cerr << "Caught runtime error: " << e.what() << '\n';
is_success = false;
}
}
}
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
}