mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 20:40:07 +00:00
[CK_TILE] Tensor-wise scaled quant gemm kernel (#2846)
* rename gemm_group_quant to gemm_quant
* Add TensorWise quant mode
* Cshuffle epilogue tests with tensor scaling
* Add tensor quant to example
* Don't use readfirstlane for reading scales - doesn't work for some reason
* Add to changelog
* revert include - from a merge problem?
* revert common.hpp include
* revert host.hpp include
* remove unused utility function
* rename quant pipeline problem
* refactor quant tests
* remove aquant utils
* use TEST_F
* fix all tests by changing gemm config
* Use typed tests
* fix copyright
[ROCm/composable_kernel commit: 4363a82bd6]
This commit is contained in:
@@ -41,8 +41,8 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTest)
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>();
|
||||
EXPECT_TRUE(result) << "Basic CShuffleEpilogue test failed";
|
||||
auto result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::None);
|
||||
EXPECT_FLOAT_EQ(result[0], 2.0F) << "Basic CShuffleEpilogue test failed";
|
||||
}
|
||||
|
||||
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
|
||||
@@ -73,8 +73,45 @@ TEST_F(CShuffleEpilogueTest, BasicHalfTestWithScale)
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
bool result = run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(true);
|
||||
EXPECT_TRUE(result) << "Scale CShuffleEpilogue test failed";
|
||||
auto result =
|
||||
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::RowCol);
|
||||
EXPECT_FLOAT_EQ(result[0], 2.0F) << "RowCol CShuffleEpilogue test failed: first element not 2";
|
||||
EXPECT_FLOAT_EQ(result[1], 4.0F)
|
||||
<< "RowCol CShuffleEpilogue test failed: second element not 2*2";
|
||||
}
|
||||
|
||||
TEST_F(CShuffleEpilogueTest, BasicHalfTestWithTensorScale)
|
||||
{
|
||||
// Basic test configuration with half_t data types
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using ODataType = ck_tile::half_t;
|
||||
|
||||
constexpr index_t kMPerBlock = 256;
|
||||
constexpr index_t kNPerBlock = 256;
|
||||
constexpr index_t MWave = 2;
|
||||
constexpr index_t NWave = 2;
|
||||
constexpr index_t MPerXdl = 32;
|
||||
constexpr index_t NPerXdl = 32;
|
||||
constexpr index_t KPerXdl = 8;
|
||||
|
||||
using TestProblem = SimpleCShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
ODataType,
|
||||
kMPerBlock,
|
||||
kNPerBlock,
|
||||
MWave,
|
||||
NWave,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
KPerXdl>;
|
||||
|
||||
auto result =
|
||||
run_cshuffle_epilogue_test<TestProblem, kMPerBlock, kNPerBlock>(ScaleType::Tensor);
|
||||
EXPECT_FLOAT_EQ(result[0], 4.0F)
|
||||
<< "TensorScale CShuffleEpilogue test failed: first element not 2*2=4";
|
||||
}
|
||||
|
||||
int main(int argc, char** argv)
|
||||
|
||||
@@ -19,8 +19,15 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum class ScaleType
|
||||
{
|
||||
None,
|
||||
RowCol,
|
||||
Tensor
|
||||
};
|
||||
|
||||
// Simple test kernel to invoke the CShuffleEpilogue
|
||||
template <typename Problem, index_t M, index_t N, bool UseScale>
|
||||
template <typename Problem, index_t M, index_t N, ScaleType Scale>
|
||||
__global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __restrict__ output_data,
|
||||
float* m_scale,
|
||||
float* n_scale)
|
||||
@@ -61,7 +68,7 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
|
||||
auto empty_ds = make_tuple();
|
||||
|
||||
// Call the epilogue
|
||||
if constexpr(UseScale)
|
||||
if constexpr(Scale == ScaleType::RowCol)
|
||||
{
|
||||
const auto m_scale_window = make_tile_window(
|
||||
make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -75,6 +82,10 @@ __global__ void test_cshuffle_epilogue_kernel(typename Problem::ODataType* __res
|
||||
{0, 0});
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, m_scale_window, n_scale_window);
|
||||
}
|
||||
else if constexpr(Scale == ScaleType::Tensor)
|
||||
{
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem, *m_scale, *n_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
Epilogue{}(output_tile_window, acc_tile, empty_ds, smem);
|
||||
@@ -113,7 +124,7 @@ using SimpleCShuffleEpilogueProblem =
|
||||
memory_operation_enum::set>;
|
||||
|
||||
template <typename Problem, index_t M, index_t N>
|
||||
bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
auto run_cshuffle_epilogue_test(ScaleType scale = ScaleType::None)
|
||||
{
|
||||
using ODataType = typename Problem::ODataType;
|
||||
|
||||
@@ -142,7 +153,7 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
dim3 gridSize(1, 1, 1);
|
||||
dim3 blockSize(kBlockSize, 1, 1);
|
||||
|
||||
if(use_scale)
|
||||
if(scale == ScaleType::RowCol)
|
||||
{
|
||||
float* m_scale;
|
||||
float* n_scale;
|
||||
@@ -155,12 +166,25 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
hipMemcpy(m_scale, h_m_scale.data(), M * sizeof(float), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(
|
||||
hipMemcpy(n_scale, h_n_scale.data(), N * sizeof(float), hipMemcpyHostToDevice));
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, true>
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::RowCol>
|
||||
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
|
||||
}
|
||||
else if(scale == ScaleType::Tensor)
|
||||
{
|
||||
float* m_scale;
|
||||
float* n_scale;
|
||||
std::vector<float> h_m_scale(1, 2.0F);
|
||||
std::vector<float> h_n_scale(1, 1.0F);
|
||||
HIP_CHECK_ERROR(hipMalloc(&m_scale, sizeof(float)));
|
||||
HIP_CHECK_ERROR(hipMalloc(&n_scale, sizeof(float)));
|
||||
HIP_CHECK_ERROR(hipMemcpy(m_scale, h_m_scale.data(), sizeof(float), hipMemcpyHostToDevice));
|
||||
HIP_CHECK_ERROR(hipMemcpy(n_scale, h_n_scale.data(), sizeof(float), hipMemcpyHostToDevice));
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::Tensor>
|
||||
<<<gridSize, blockSize>>>(device_output, m_scale, n_scale);
|
||||
}
|
||||
else
|
||||
{
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, false>
|
||||
test_cshuffle_epilogue_kernel<Problem, M, N, ScaleType::None>
|
||||
<<<gridSize, blockSize>>>(device_output, nullptr, nullptr);
|
||||
}
|
||||
|
||||
@@ -172,20 +196,10 @@ bool run_cshuffle_epilogue_test(bool use_scale = false)
|
||||
HIP_CHECK_ERROR(hipMemcpy(
|
||||
host_output.data(), device_output, output_size * sizeof(ODataType), hipMemcpyDeviceToHost));
|
||||
|
||||
// Basic verification - just check that output has a 2, and 4 if using scaling
|
||||
bool has_2 =
|
||||
type_convert<float>(host_output[0]) > 1.9F && type_convert<float>(host_output[0]) < 2.1F;
|
||||
bool scale_has_4 = true;
|
||||
if(use_scale)
|
||||
{
|
||||
scale_has_4 = type_convert<float>(host_output[1]) > 3.9F &&
|
||||
type_convert<float>(host_output[1]) < 4.1F;
|
||||
}
|
||||
|
||||
// Cleanup
|
||||
HIP_CHECK_ERROR(hipFree(device_output));
|
||||
|
||||
return has_2 && scale_has_4;
|
||||
return host_output;
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -6,14 +6,9 @@ endif()
|
||||
list(APPEND TEST_GEMM_COMPILE_OPTIONS -mllvm -enable-noalias-to-md-conversion=0)
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
set(TEST_GEMM_NAME test_tile_gemm_aquant_basic)
|
||||
set(QUANT_TYPES fp8 bf8 i4fp8 i4bf8 i4f32fp8 i4f32bf8)
|
||||
|
||||
foreach(QUANT_TYPE ${QUANT_TYPES})
|
||||
add_gtest_executable(${TEST_GEMM_NAME}_${QUANT_TYPE} test_gemm_aquant_basic_${QUANT_TYPE}.cpp)
|
||||
target_compile_options(${TEST_GEMM_NAME}_${QUANT_TYPE} PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
endforeach()
|
||||
|
||||
# Typed Test Suite for GEMM Quantization
|
||||
add_gtest_executable(test_tile_gemm_quant_typed test_gemm_quant_typed.cpp)
|
||||
target_compile_options(test_tile_gemm_quant_typed PRIVATE ${TEST_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile quant gemm tests for current target")
|
||||
endif()
|
||||
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4bf8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4f32bf8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4f32fp8"); }
|
||||
@@ -1,6 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_run_gemm_aquant_example.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("i4fp8"); }
|
||||
@@ -1,243 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_group_quant.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_PREFILL 1
|
||||
#define CK_TILE_PIPELINE_DECODE 2
|
||||
#define CK_TILE_PIPELINE_PRESHUFFLEQUANT 3
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
class ArgumentsNotSupportedException : public std::logic_error
|
||||
{
|
||||
public:
|
||||
explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {}
|
||||
};
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigDecode : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_DECODE;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPrefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PREFILL;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleQuant : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLEQUANT;
|
||||
static constexpr bool PreshuffleQuant = true;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_ = ADataType_,
|
||||
typename CDataType_ = ADataType_,
|
||||
typename QDataType_ = float>
|
||||
struct GemmQuantTypeConfig
|
||||
{
|
||||
using ADataType = ADataType_;
|
||||
using QDataType = QDataType_;
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = float;
|
||||
using CDataType = CDataType_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("aq_layout", "R", "Aq tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_q", "0", "Tensor AQ stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "i4fp8", "data type. fp8/bf8/i4fp8/i4bf8/i4f32fp8/i4f32bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent")
|
||||
.insert("as_br_cr", "false", "Choose between as_br_cr and as_bs_cr");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s);
|
||||
179
test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
Normal file
179
test/ck_tile/gemm_block_scale/test_gemm_quant_base.hpp
Normal file
@@ -0,0 +1,179 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <stdexcept>
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/host/check_err.hpp"
|
||||
#include "ck_tile/host/reference/reference_gemm.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/gemm_quant.hpp"
|
||||
|
||||
// Forward declarations for quant type-specific implementations
|
||||
template <ck_tile::QuantType QT>
|
||||
struct QuantTypeTraits;
|
||||
|
||||
// Base class for common quant gemm functionality
|
||||
template <typename Tuple, typename Derived>
|
||||
class TestCkTileGemmQuantBase : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using BDataType = std::tuple_element_t<4, Tuple>;
|
||||
using QDataType = std::tuple_element_t<5, Tuple>;
|
||||
using CDataType = std::tuple_element_t<6, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<7, Tuple>::value;
|
||||
using GemmConfig = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr uint32_t QuantGroupSize = std::tuple_element_t<9, Tuple>::value;
|
||||
using AccDataType = float; // accumulate always in float
|
||||
|
||||
// Get the quant-type specific data types from traits
|
||||
using QuantTraits = QuantTypeTraits<QuantType>;
|
||||
using ComputeDataType = typename QuantTraits::template ComputeDataType<ADataType, BDataType>;
|
||||
|
||||
static constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
static constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
static constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
static constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
static constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
public:
|
||||
void SetUp() override { static_cast<Derived*>(this)->SetUpQuantTypeSpecific(); }
|
||||
|
||||
void TearDown() override { static_cast<Derived*>(this)->TearDownQuantTypeSpecific(); }
|
||||
|
||||
// Common test execution logic
|
||||
void invoke_quant_gemm(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
constexpr bool kPreshuffle = false;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
kPreshuffle,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantType>;
|
||||
|
||||
// Let the derived class create the appropriate pipeline and epilogue
|
||||
static_cast<Derived*>(this)
|
||||
->template run_quant_gemm_impl<CodegenGemmShape, TilePartitioner, CodegenGemmTraits>(
|
||||
args, s);
|
||||
}
|
||||
|
||||
void RunTest(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
// Generate test data and run the kernel
|
||||
static_cast<Derived*>(this)->run_test_with_validation(M, N, K);
|
||||
}
|
||||
|
||||
// Helper function to check layout
|
||||
template <typename Layout>
|
||||
static constexpr auto is_row_major(Layout)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(Layout{})>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
// Tolerance calculation function for validation
|
||||
template <typename ADataType_, typename BDataType_, typename AccDataType_, typename CDataType_>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType_) < sizeof(BDataType_), ADataType_, BDataType_>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType_, AccDataType_>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType_, AccDataType_>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType_, CDataType_, CDataType_>(kbatch);
|
||||
const auto atol_split_k =
|
||||
ck_tile::get_absolute_threshold<CDataType_, CDataType_, CDataType_>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
};
|
||||
|
||||
// Define generic QuantTypeTraits template (will be specialized)
|
||||
template <ck_tile::QuantType QT>
|
||||
struct QuantTypeTraits
|
||||
{
|
||||
static_assert(QT == ck_tile::QuantType::AQuantGrouped ||
|
||||
QT == ck_tile::QuantType::BQuantGrouped ||
|
||||
QT == ck_tile::QuantType::RowColQuant ||
|
||||
QT == ck_tile::QuantType::TensorQuant,
|
||||
"Unsupported quantization type");
|
||||
};
|
||||
|
||||
// Specialization for AQuantGrouped
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::AQuantGrouped>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = BDataType; // For AQuant, compute type is BDataType
|
||||
|
||||
static constexpr const char* name = "aquant";
|
||||
};
|
||||
|
||||
// Specialization for BQuantGrouped
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::BQuantGrouped>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = ADataType; // For BQuant, compute type is ADataType
|
||||
|
||||
static constexpr const char* name = "bquant";
|
||||
};
|
||||
|
||||
// Specialization for RowColQuant
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::RowColQuant>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = ADataType; // For RowColQuant, compute type is ADataType
|
||||
|
||||
static constexpr const char* name = "rowcol";
|
||||
};
|
||||
|
||||
// Specialization for TensorQuant
|
||||
template <>
|
||||
struct QuantTypeTraits<ck_tile::QuantType::TensorQuant>
|
||||
{
|
||||
template <typename ADataType, typename BDataType>
|
||||
using ComputeDataType = ADataType; // For TensorQuant, compute type is ADataType
|
||||
|
||||
static constexpr const char* name = "tensor";
|
||||
};
|
||||
919
test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
Normal file
919
test/ck_tile/gemm_block_scale/test_gemm_quant_fixtures.hpp
Normal file
@@ -0,0 +1,919 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "test_gemm_quant_base.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool PreshuffleQuant = false;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
// Default GEMM tile sizes for tests
|
||||
static constexpr ck_tile::index_t M_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Tile = 64;
|
||||
static constexpr ck_tile::index_t K_Tile = 256;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = 32;
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmAQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmAQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
// AQuant-specific data generation
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// AQuant uses grouped quantization for A matrix
|
||||
const ck_tile::index_t AQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
const ck_tile::index_t stride_AQ =
|
||||
ck_tile::get_default_stride(M, AQK, 0, this->is_row_major(ALayout{}));
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<QDataType> aq_m_aqk(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.0f, 5.0f}(a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{-2.0f, 2.0f}(aq_m_aqk);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> temp = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(temp);
|
||||
a_m_k_dev_buf.ToDevice(temp.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
aq_m_aqk_dev_buf.GetDeviceBuffer(), // aq_ptr (scales)
|
||||
nullptr, // bq_ptr (not used for AQuant)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
AQK, // QK_A
|
||||
0, // QK_B (not used for AQuant)
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
stride_AQ,
|
||||
0 // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference AQuant implementation
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "AQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "AQuantGrouped - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// AQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for AQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
// BQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmBQuant : public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmBQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// BQuant uses grouped quantization for B matrix
|
||||
const ck_tile::index_t BQK = ck_tile::integer_divide_ceil(K, QuantGroupSize);
|
||||
const ck_tile::index_t stride_BQ = BQK;
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> bq_bqk_n(
|
||||
ck_tile::host_tensor_descriptor(BQK, N, stride_BQ, this->is_row_major(BLayout{})));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{0.f, 1.f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(bq_bqk_n);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem bq_bqk_n_dev_buf(bq_bqk_n.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> temp = b_k_n;
|
||||
ck_tile::permute_vectors_i4x4_b(temp);
|
||||
b_k_n_dev_buf.ToDevice(temp.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
bq_bqk_n_dev_buf.ToDevice(bq_bqk_n.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
nullptr, // aq_ptr (not used for BQuant)
|
||||
bq_bqk_n_dev_buf.GetDeviceBuffer(), // bq_ptr (scales)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
0, // QK_A (not used for BQuant)
|
||||
BQK, // QK_B
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
0,
|
||||
stride_BQ // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference BQuant implementation
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
false>(a_m_k, bq_bqk_n, b_k_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "BQuantGrouped validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "BQuantGrouped - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// BQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::GemmBQuantPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::BQuantGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
false, // transpose_c
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for BQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
// RowColQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmRowColQuant
|
||||
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmRowColQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// RowColQuant uses per-row and per-column scales
|
||||
const ck_tile::index_t stride_row_scales = 1;
|
||||
const ck_tile::index_t stride_col_scales = 1;
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> row_scales_m(ck_tile::host_tensor_descriptor(
|
||||
M, 1, stride_row_scales, ck_tile::bool_constant<true>{}));
|
||||
ck_tile::HostTensor<QDataType> col_scales_n(ck_tile::host_tensor_descriptor(
|
||||
N, 1, stride_col_scales, ck_tile::bool_constant<true>{}));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(row_scales_m);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(col_scales_n);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem row_scales_dev_buf(row_scales_m.get_element_space_size() *
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem col_scales_dev_buf(col_scales_n.get_element_space_size() *
|
||||
sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
row_scales_dev_buf.ToDevice(row_scales_m.data());
|
||||
col_scales_dev_buf.ToDevice(col_scales_n.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
row_scales_dev_buf.GetDeviceBuffer(), // aq_ptr (row scales)
|
||||
col_scales_dev_buf.GetDeviceBuffer(), // bq_ptr (col scales)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
1, // QK_A (row scales)
|
||||
1, // QK_B (col scales)
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
stride_row_scales,
|
||||
stride_col_scales // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference RowColQuant implementation
|
||||
ck_tile::reference_gemm_rowcol_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k, row_scales_m, b_k_n, col_scales_n, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "RowColQuant validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "RowColQuant - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// RowColQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::RowColQuant>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for RowColQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
|
||||
// TensorQuant-specific test fixture
|
||||
template <typename Tuple>
|
||||
class TestCkTileGemmTensorQuant
|
||||
: public TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>
|
||||
{
|
||||
using Base = TestCkTileGemmQuantBase<Tuple, TestCkTileGemmTensorQuant<Tuple>>;
|
||||
friend Base;
|
||||
|
||||
public:
|
||||
using typename Base::AccDataType;
|
||||
using typename Base::ADataType;
|
||||
using typename Base::ALayout;
|
||||
using typename Base::BDataType;
|
||||
using typename Base::BLayout;
|
||||
using typename Base::CDataType;
|
||||
using typename Base::CLayout;
|
||||
using typename Base::ComputeDataType;
|
||||
using typename Base::QDataType;
|
||||
|
||||
static constexpr auto QuantType = Base::QuantType;
|
||||
static constexpr uint32_t QuantGroupSize = Base::QuantGroupSize;
|
||||
|
||||
protected:
|
||||
void SetUpQuantTypeSpecific() {}
|
||||
void TearDownQuantTypeSpecific() {}
|
||||
|
||||
void run_test_with_validation(ck_tile::index_t M, ck_tile::index_t N, ck_tile::index_t K)
|
||||
{
|
||||
const ck_tile::index_t stride_A = K;
|
||||
const ck_tile::index_t stride_B = K;
|
||||
const ck_tile::index_t stride_C = M;
|
||||
|
||||
// TensorQuant uses single scalar scale for each tensor
|
||||
const ck_tile::index_t stride_scale_a = 1;
|
||||
const ck_tile::index_t stride_scale_b = 1;
|
||||
|
||||
// Generate test data
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, this->is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, this->is_row_major(BLayout{})));
|
||||
ck_tile::HostTensor<QDataType> scale_a(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_scale_a, ck_tile::bool_constant<true>{}));
|
||||
ck_tile::HostTensor<QDataType> scale_b(
|
||||
ck_tile::host_tensor_descriptor(1, 1, stride_scale_b, ck_tile::bool_constant<true>{}));
|
||||
|
||||
// Initialize data with random values
|
||||
ck_tile::FillUniformDistribution<ADataType>{-0.5f, 0.5f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-0.5f, 0.5f}(b_k_n);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_a);
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.001f, 0.01f}(scale_b);
|
||||
|
||||
// Allocate device memory
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size() * sizeof(ADataType));
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size() * sizeof(BDataType));
|
||||
ck_tile::DeviceMem scale_a_dev_buf(scale_a.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem scale_b_dev_buf(scale_b.get_element_space_size() * sizeof(QDataType));
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(M * N * sizeof(CDataType));
|
||||
|
||||
// Copy to device
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
scale_a_dev_buf.ToDevice(scale_a.data());
|
||||
scale_b_dev_buf.ToDevice(scale_b.data());
|
||||
|
||||
// Create args for kernel execution
|
||||
ck_tile::QuantGemmHostArgs args{
|
||||
a_m_k_dev_buf.GetDeviceBuffer(), // a_ptr
|
||||
b_k_n_dev_buf.GetDeviceBuffer(), // b_ptr
|
||||
c_m_n_dev_buf.GetDeviceBuffer(), // c_ptr
|
||||
scale_a_dev_buf.GetDeviceBuffer(), // aq_ptr (scale A)
|
||||
scale_b_dev_buf.GetDeviceBuffer(), // bq_ptr (scale B)
|
||||
1, // k_batch
|
||||
M,
|
||||
N,
|
||||
K, // M, N, K
|
||||
1, // QK_A (tensor scale)
|
||||
1, // QK_B (tensor scale)
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
stride_scale_a,
|
||||
stride_scale_b // strides
|
||||
};
|
||||
|
||||
// Run the kernel
|
||||
ck_tile::stream_config stream_config{};
|
||||
this->invoke_quant_gemm(args, stream_config);
|
||||
|
||||
// Validation using reference implementation
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
// Run reference TensorQuant implementation
|
||||
ck_tile::reference_gemm_tensor_quant<ADataType,
|
||||
QDataType,
|
||||
BDataType,
|
||||
QDataType,
|
||||
AccDataType,
|
||||
CDataType>(
|
||||
a_m_k, scale_a, b_k_n, scale_b, c_m_n_host_ref);
|
||||
|
||||
// Get device result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, this->is_row_major(CLayout{})));
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.mData.data());
|
||||
|
||||
// Calculate error tolerances
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol =
|
||||
this->template calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, 1, max_accumulated_value);
|
||||
|
||||
// Validate results
|
||||
bool pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
EXPECT_TRUE(pass) << "TensorQuant validation failed with M=" << M << ", N=" << N
|
||||
<< ", K=" << K;
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "TensorQuant - Relative error threshold: "
|
||||
<< rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
// TensorQuant-specific pipeline implementation
|
||||
template <typename CodegenGemmShape, typename TilePartitioner, typename CodegenGemmTraits>
|
||||
void run_quant_gemm_impl(const ck_tile::QuantGemmHostArgs& args,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + Base::K_Tile - 1) / Base::K_Tile * Base::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr bool transpose_c = false;
|
||||
|
||||
using PipelineProblem = ck_tile::GemmRowColTensorQuantPipelineProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
transpose_c,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
Base::M_Warp,
|
||||
Base::N_Warp,
|
||||
Base::M_Warp_Tile,
|
||||
Base::N_Warp_Tile,
|
||||
Base::K_Warp_Tile,
|
||||
transpose_c,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
GemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::TensorQuant>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Arguments not supported for TensorQuant kernel");
|
||||
}
|
||||
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfigBase::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
};
|
||||
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
};
|
||||
64
test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
Normal file
64
test/ck_tile/gemm_block_scale/test_gemm_quant_typed.cpp
Normal file
@@ -0,0 +1,64 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
|
||||
#include "test_gemm_quant_fixtures.hpp"
|
||||
|
||||
// Type aliases for readability
|
||||
using RowMajor = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using ColumnMajor = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Half = ck_tile::half_t;
|
||||
using PkInt4 = ck_tile::pk_int4_t;
|
||||
using AQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::AQuantGrouped>;
|
||||
using BQuantGrouped = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
using GroupSize = std::integral_constant<unsigned int, 128>;
|
||||
|
||||
// Type combinations for each quantization type
|
||||
// clang-format off
|
||||
using AQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, FP8, FP8, Half, AQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, PkInt4, BF8, BF8, Half, AQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using BQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, PkInt4, FP8, Half, BQuantGrouped, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, PkInt4, BF8, Half, BQuantGrouped, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using RowColQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, RowColQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, RowColQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// clang-format off
|
||||
using TensorQuantTypes = ::testing::Types<
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, FP8, FP8, float, Half, TensorQuant, GemmConfigBase, GroupSize>,
|
||||
std::tuple<RowMajor, ColumnMajor, RowMajor, BF8, BF8, float, Half, TensorQuant, GemmConfigBase, GroupSize>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
// Test suites for each quantization type
|
||||
TYPED_TEST_SUITE(TestCkTileGemmAQuant, AQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmBQuant, BQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmRowColQuant, RowColQuantTypes);
|
||||
TYPED_TEST_SUITE(TestCkTileGemmTensorQuant, TensorQuantTypes);
|
||||
|
||||
#include "test_gemm_quant_ut_cases.inc"
|
||||
28
test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc
Normal file
28
test/ck_tile/gemm_block_scale/test_gemm_quant_ut_cases.inc
Normal file
@@ -0,0 +1,28 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
// AQuant tests
|
||||
TYPED_TEST(TestCkTileGemmAQuant, AQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// BQuant tests
|
||||
TYPED_TEST(TestCkTileGemmBQuant, BQuantGroupedTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// RowColQuant tests
|
||||
TYPED_TEST(TestCkTileGemmRowColQuant, RowColQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
|
||||
// TensorQuant tests
|
||||
TYPED_TEST(TestCkTileGemmTensorQuant, TensorQuantTest)
|
||||
{
|
||||
this->run_test_with_validation(1024, 1024, 1024);
|
||||
}
|
||||
@@ -1,616 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <random>
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_aquant_utils.hpp"
|
||||
#include "ck_tile/host/permute_pk_int4.hpp"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ComputeDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float gemm_calc_aquant(const ck_tile::QuantGemmHostArgs& args, const ck_tile::stream_config& s)
|
||||
{
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
static_assert(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
|
||||
constexpr ck_tile::index_t M_Tile = GemmConfig::M_Tile;
|
||||
constexpr ck_tile::index_t N_Tile = GemmConfig::N_Tile;
|
||||
constexpr ck_tile::index_t K_Tile = GemmConfig::K_Tile;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = GemmConfig::M_Warp;
|
||||
constexpr ck_tile::index_t N_Warp = GemmConfig::N_Warp;
|
||||
constexpr ck_tile::index_t K_Warp = GemmConfig::K_Warp;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = GemmConfig::M_Warp_Tile;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = GemmConfig::N_Warp_Tile;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = GemmConfig::K_Warp_Tile;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits = ck_tile::TileGemmQuantTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
false, // preshuffle
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
using GemmPipelineProblem = ck_tile::GemmPipelineProblemBase<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
ComputeDataType>;
|
||||
|
||||
using BaseGemmPipeline = ck_tile::BaseAQuantGemmPipelineAgBgCrCompV3<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t K_split = (args.K + K_Tile - 1) / K_Tile * K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
constexpr bool transposed_warp_gemm = false;
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
|
||||
using CodegenPipelineProblem =
|
||||
ck_tile::GemmAQuantPipelineProblem<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CodegenGemmShape,
|
||||
CodegenGemmTraits,
|
||||
QuantGroupSize,
|
||||
transposed_warp_gemm,
|
||||
ComputeDataType,
|
||||
ck_tile::GemmPipelineScheduler::Intrawave,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using CodegenGemmPipeline = ck_tile::AQuantGemmPipelineAgBgCrCompV3<CodegenPipelineProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
transposed_warp_gemm,
|
||||
ck_tile::memory_operation_enum::set>>;
|
||||
using Kernel = ck_tile::QuantGemmKernel<TilePartitioner,
|
||||
CodegenGemmPipeline,
|
||||
GemmEpilogue,
|
||||
ck_tile::QuantType::AQuantGrouped>;
|
||||
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
{
|
||||
throw std::runtime_error("split-k is not supported yet!");
|
||||
}
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw std::runtime_error("Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
return BaseGemmPipeline::TailHandler(Run, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename AQDataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
uint32_t QuantGroupSize>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& aq_m_aqk_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t AQK,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_AQ,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat)
|
||||
{
|
||||
ck_tile::QuantGemmHostArgs args;
|
||||
args.a_ptr = a_m_k_dev_buf.GetDeviceBuffer();
|
||||
args.aq_ptr = aq_m_aqk_dev_buf.GetDeviceBuffer();
|
||||
args.b_ptr = b_k_n_dev_buf.GetDeviceBuffer();
|
||||
args.c_ptr = c_m_n_dev_buf.GetDeviceBuffer();
|
||||
args.k_batch = kbatch;
|
||||
args.M = M;
|
||||
args.N = N;
|
||||
args.K = K;
|
||||
args.QK_A = AQK;
|
||||
args.stride_A = stride_A;
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
args.stride_AQ = stride_AQ;
|
||||
|
||||
float ave_time = gemm_calc_aquant<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
BDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte = sizeof(ADataType) * M * K + sizeof(AQDataType) * M * AQK +
|
||||
sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideAQ =" << stride_AQ << " StrideB =" << stride_B
|
||||
<< " StrideC =" << stride_C << " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name << " C_Layout =" << CLayout::name
|
||||
<< " A_Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " AQ_Type = " << DataTypeTraits<AQDataType>::name
|
||||
<< " B_Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " Acc_Type = " << DataTypeTraits<AccDataType>::name
|
||||
<< " C_Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename TypeConfig,
|
||||
uint32_t QuantGroupSize,
|
||||
typename ALayout,
|
||||
typename AQLayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool run_gemm_test_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const AQLayout aq_layout = AQLayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using AQDataType = typename TypeConfig::QDataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using AccDataType = typename TypeConfig::AccDataType;
|
||||
using CDataType = typename TypeConfig::CDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
if(K % QuantGroupSize != 0)
|
||||
{
|
||||
throw std::runtime_error("K must be aligned with QuantGroupSize");
|
||||
}
|
||||
|
||||
ck_tile::index_t AQK = K / QuantGroupSize;
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_AQ = arg_parser.get_int("stride_q");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_AQ = ck_tile::get_default_stride(M, AQK, stride_AQ, is_row_major(aq_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<AQDataType> aq_m_aqk(
|
||||
ck_tile::host_tensor_descriptor(M, AQK, stride_AQ, is_row_major(aq_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
std::random_device rd;
|
||||
std::mt19937 gen(rd());
|
||||
std::uniform_int_distribution<std::uint32_t> fill_seed(0, 500);
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ck_tile::pk_int4_t>{-5.0f, 5.0f, fill_seed(gen)}(
|
||||
a_m_k);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-2.0f, 3.0f, fill_seed(gen)}(a_m_k);
|
||||
}
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-2.0f, 2.0f, fill_seed(gen)}(aq_m_aqk);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.0f, 5.0f, fill_seed(gen)}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
std::cout << "Monotonic initialization is not supported." << std::endl;
|
||||
return true;
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillConstant<ADataType>{static_cast<ADataType>(0x22)}(a_m_k);
|
||||
ck_tile::FillConstant<AQDataType>{static_cast<AQDataType>(0.5f)}(aq_m_aqk);
|
||||
ck_tile::FillConstant<BDataType>{static_cast<BDataType>(0x38)}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
aq_m_aqk.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem aq_m_aqk_dev_buf(aq_m_aqk.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
if constexpr(std::is_same_v<ADataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<ADataType> a_m_k_dev = a_m_k;
|
||||
ck_tile::permute_vectors_i4x4_b(a_m_k_dev);
|
||||
a_m_k_dev_buf.ToDevice(a_m_k_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
}
|
||||
aq_m_aqk_dev_buf.ToDevice(aq_m_aqk.data());
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
AQLayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
QuantGroupSize>(a_m_k_dev_buf,
|
||||
aq_m_aqk_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
AQK,
|
||||
stride_A,
|
||||
stride_AQ,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm_quant<ADataType,
|
||||
AQDataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
QuantGroupSize,
|
||||
true>(a_m_k, aq_m_aqk, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
if(!pass)
|
||||
{
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
}
|
||||
std::cout << "CPU verification " << (pass ? "Passed!" : "Failed ...") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
std::cout << "GPU verification is not implemented yet. Re-run with -v=1" << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename TypeConfig, uint32_t QuantGroupSize>
|
||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<typename TypeConfig::ADataType, ck_tile::pk_int4_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::fp8_t> ||
|
||||
std::is_same_v<typename TypeConfig::ADataType, ck_tile::bf8_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, TypeConfig, QuantGroupSize>(
|
||||
argc, argv, Row{}, Row{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for A.");
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
using TypeConfig =
|
||||
decltype(GemmQuantTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t, float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::fp8_t>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
ck_tile::bf8_t>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32fp8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "i4f32bf8")
|
||||
{
|
||||
using TypeConfig = decltype(GemmQuantTypeConfig<ck_tile::pk_int4_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t,
|
||||
float>{});
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::pk_int4_t>, TypeConfig, 128>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_combinations(std::string const& data_type)
|
||||
{
|
||||
// Define possible values for each parameter
|
||||
std::vector<std::vector<std::string>> mnk_values = {{
|
||||
"1",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"2",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"16",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"17",
|
||||
"2048",
|
||||
"5120",
|
||||
},
|
||||
{
|
||||
"2047",
|
||||
"5120",
|
||||
"1024",
|
||||
},
|
||||
{
|
||||
"2048",
|
||||
"5120",
|
||||
"1024",
|
||||
}};
|
||||
std::vector<std::string> prec_values = {data_type};
|
||||
|
||||
// We'll store all our arguments as strings first
|
||||
std::vector<std::string> arg_strings = {"test_tile_gemm_aquant_basic",
|
||||
"", // m placeholder
|
||||
"", // n placeholder
|
||||
"", // k placeholder
|
||||
"", // prec placeholder
|
||||
"-init=0",
|
||||
"-v=1",
|
||||
"-warmup=0",
|
||||
"-repeat=1"};
|
||||
|
||||
// Create an array of const char pointers for argv
|
||||
constexpr size_t ARG_COUNT = 9;
|
||||
constexpr size_t ARG_MAX_LEN = 64;
|
||||
char args[ARG_COUNT][ARG_MAX_LEN];
|
||||
char* argv[ARG_COUNT];
|
||||
|
||||
// Run all combinations
|
||||
bool is_success = true;
|
||||
for(const auto& mnk : mnk_values)
|
||||
{
|
||||
arg_strings[1] = "-m=" + mnk[0];
|
||||
arg_strings[2] = "-n=" + mnk[1];
|
||||
arg_strings[3] = "-k=" + mnk[2];
|
||||
|
||||
for(const auto& prec : prec_values)
|
||||
{
|
||||
arg_strings[4] = "-prec=" + prec;
|
||||
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test<GemmConfigDecode>(ARG_COUNT, argv) && is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
|
||||
}
|
||||
Reference in New Issue
Block a user