mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
test(grouped_gemm): add unit tests for grouped_gemm bquant with preshuffleB true (#3119)
* add tensorwise quant in grouped gemm
* fix example issue
* update test cases
* format codes
* clang format
* use GTEST_FAIL
* add bquant to grouped_gemm
* add tensorwise quant in grouped gemm
* fix example issue
* update test cases
* format codes
* clang format
* use GTEST_FAIL
* fix a bug in test_grouped_gemm_util
* skip test when use wmma on grouped_quant kernel
* change cmake
* fix a bug in test_grouped_gemm_util
* skip test when use wmma on grouped_quant kernel
* change cmake
* tests(quant_grouped_gemm): add unit tests to cover bquant in grouped_gemm
* Update test/ck_tile/grouped_gemm_quant/test_grouped_gemm_util_quant.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* Update example/ck_tile/17_grouped_gemm/quant_grouped_gemm.hpp
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
* feat: add bf8 support
* chore: remove unnecessary decltype usage
* chore: add default quant_mode to function signature as fallback
* fix: pass correct runtime pipeline params in grouped_gemm bquant kernel
Calculate has_hot_loop, num_loop, and tail_number on device side for each
GEMM problem instead of using default values. This fixes incorrect results
when different problems in the group have different K dimensions.
* chore: set default quant mode in function signature
* test: add additional test cases to cover edge case of no hotloop
* change code based on comments
* WIP: bquant preshuffle b compiles but gives numerical error
* feat(grouped_gemm_quant): bquant with preshuffleB support added to grouped_gemm example & kernel
* refactor: refactor code after merge commit
* chore: remove print statements
* test(grouped_gemm): split test cases by quant mode to reduce compilation time and add bquant-preshuffleB mode test cases
---------
Co-authored-by: kyle-256 <Kyle.Zhao@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
[ROCm/composable_kernel commit: 8f1274d9b6]
This commit is contained in:
@@ -49,7 +49,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
false, // PreshuffleQuant
|
||||
false, // PreshuffleB
|
||||
GemmConfig::PreshuffleB, // PreshuffleB
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -58,7 +58,7 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BQLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
true>;
|
||||
true>; // Persistence
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
@@ -86,10 +86,14 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
|
||||
using GemmPipeline =
|
||||
typename std::conditional<QuantMode == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantMode == ck_tile::QuantType::RowColQuant ||
|
||||
QuantMode == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<GemmConfig::PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -141,5 +145,6 @@ float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
return !run_grouped_gemm_example<GemmConfigComputeV3_2>(argc, argv);
|
||||
int result1 = !run_grouped_gemm_example<GemmConfigPreshuffleB_Bquant_prefill>(argc, argv);
|
||||
return result1;
|
||||
}
|
||||
|
||||
@@ -10,9 +10,6 @@
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_BQUANT_COMPUTE_V3 2
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
@@ -31,6 +28,22 @@ constexpr ck_tile::index_t get_k_warp_tile()
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
@@ -67,8 +80,9 @@ struct GemmConfigBase
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr bool PreshuffleB = false;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
@@ -85,10 +99,26 @@ struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
};
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
template <typename PrecType>
|
||||
struct GemmConfigPreshuffleB_Bquant_prefill : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 4;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile =
|
||||
get_k_from_preshuffled_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool PreshuffleB = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
@@ -118,7 +148,8 @@ auto create_args(int argc, char* argv[])
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel.")
|
||||
.insert("group_count", "8", "group count.")
|
||||
.insert("kbatch", "1", "kbatch for SplitK")
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol");
|
||||
.insert("quant_mode", "bquant", "Choose bquant (default), tensor, or rowcol")
|
||||
.insert("init", "0", "0. Random, 2. One(s) (Constant)");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
|
||||
@@ -163,6 +163,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
const int repeat = arg_parser.get_int("repeat");
|
||||
const int warmup = arg_parser.get_int("warmup");
|
||||
const int kbatch = arg_parser.get_int("kbatch");
|
||||
const int init_method = arg_parser.get_int("init");
|
||||
bool validate = arg_parser.get_bool("validate");
|
||||
const ck_tile::index_t QuantGroupSize = 128;
|
||||
|
||||
@@ -203,6 +204,7 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
|
||||
Ms.push_back(256 + 256 * i);
|
||||
Ns.push_back(256 + 512 * i);
|
||||
Ks.push_back(512 + 128 * i);
|
||||
@@ -280,6 +282,12 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
stride_AQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
stride_BQs[i] = 1; // Tensor quantization: tensor shape [1]
|
||||
}
|
||||
else if constexpr(QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
stride_AQs[i] = 0; // No A quantization
|
||||
stride_BQs[i] =
|
||||
ck_tile::get_default_stride(BQK, N, stride_BQs[i], is_row_major(bq_layout));
|
||||
}
|
||||
|
||||
a_m_k_tensors.push_back(ck_tile::HostTensor<ADataType>(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_As[i], is_row_major(a_layout))));
|
||||
@@ -313,10 +321,20 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
<< " b_k_n: " << b_k_n_tensors[i].mDesc << " c_m_n: " << c_m_n_tensors[i].mDesc
|
||||
<< " aq: " << aq_tensors[i].mDesc << " bq: " << bq_tensors[i].mDesc << std::endl;
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-1.f, 1.f}(a_m_k_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-1.f, 1.f}(b_k_n_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<AQDataType>{-1.f, 1.f}(aq_tensors[i]);
|
||||
ck_tile::FillUniformDistribution<BQDataType>{-1.f, 1.f}(bq_tensors[i]);
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.push_back(std::make_unique<ck_tile::DeviceMem>(
|
||||
a_m_k_tensors[i].get_element_space_size_in_bytes()));
|
||||
@@ -329,8 +347,18 @@ int run_grouped_gemm_example_with_layouts(int argc,
|
||||
bq_dev_buf.push_back(
|
||||
std::make_unique<ck_tile::DeviceMem>(bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
if constexpr(GemmConfig::PreshuffleB && QuantMode == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
ck_tile::HostTensor<BDataType> b_shuffle_host =
|
||||
ck_tile::shuffle_b<GemmConfig>(b_k_n_tensors[i]);
|
||||
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
|
||||
@@ -20,7 +20,7 @@ struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
@@ -483,6 +483,7 @@ struct QuantGemmKernel
|
||||
const QuantGemmKernelArgs& kargs,
|
||||
const SplitKBatchOffset& splitk_batch_offset)
|
||||
{
|
||||
|
||||
static_assert(!TilePartitioner::BlockGemmShape::PermuteA, "Not implemented!");
|
||||
const auto& a_tensor_view = [&]() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
@@ -790,6 +791,7 @@ struct QuantGemmKernel
|
||||
}();
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
|
||||
return make_tuple(a_pad_view, aq_pad_view, b_flat_view, bq_pad_view, c_pad_view);
|
||||
}
|
||||
else
|
||||
@@ -802,6 +804,7 @@ struct QuantGemmKernel
|
||||
CK_TILE_DEVICE static auto
|
||||
MakeGemmTileWindows(const PadView& views, const index_t i_m, const index_t i_n)
|
||||
{
|
||||
|
||||
const auto& a_pad_view = views.at(I0);
|
||||
const auto& aq_pad_view = views.at(I1);
|
||||
const auto& b_pad_view = views.at(I2);
|
||||
@@ -867,6 +870,7 @@ struct QuantGemmKernel
|
||||
const auto& b_block_window = [&]() {
|
||||
if constexpr(PreshuffleB)
|
||||
{
|
||||
|
||||
return make_tile_window(
|
||||
b_pad_view,
|
||||
make_tuple(number<GemmPipeline::flatNPerWarp>{},
|
||||
|
||||
@@ -317,13 +317,88 @@ struct QuantGroupedGemmKernel
|
||||
const BQDataType* bq_ptr = static_cast<const BQDataType*>(kargs.bq_ptr);
|
||||
CDataType* c_ptr = static_cast<CDataType*>(kargs.c_ptr);
|
||||
|
||||
static_assert(GemmPipeline::DoubleSmemBuffer == false,
|
||||
"DoubleSmemBuffer needs to be false");
|
||||
// allocate LDS
|
||||
__shared__ char smem_ptr_0[GetSmemSize()];
|
||||
|
||||
RunGemmWithPipelineSelection(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
|
||||
// Only for BQuantGrouped DoubleSmemBuffer is supported
|
||||
if constexpr(GemmPipeline::DoubleSmemBuffer == true &&
|
||||
kQuantType == QuantType::BQuantGrouped)
|
||||
{
|
||||
|
||||
__shared__ char smem_ptr_1[GetSmemSize()];
|
||||
RunGemmWithPipelineSelection2LDS(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
RunGemmWithPipelineSelection(a_ptr,
|
||||
b_ptr,
|
||||
aq_ptr,
|
||||
bq_ptr,
|
||||
c_ptr,
|
||||
smem_ptr_0,
|
||||
kargs,
|
||||
splitk_batch_offset,
|
||||
i_m,
|
||||
i_n);
|
||||
}
|
||||
}
|
||||
|
||||
template <memory_operation_enum DstInMemOp = memory_operation_enum::set>
|
||||
CK_TILE_DEVICE static void
|
||||
RunGemmWithPipelineSelection2LDS(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const AQDataType* aq_ptr,
|
||||
const BQDataType* bq_ptr,
|
||||
CDataType* c_ptr,
|
||||
void* smem_ptr_0,
|
||||
void* smem_ptr_1,
|
||||
const QuantGroupedGemmKernelArgs& kargs,
|
||||
const typename Base::SplitKBatchOffset& splitk_batch_offset,
|
||||
const index_t block_idx_m,
|
||||
const index_t block_idx_n)
|
||||
{
|
||||
static_assert(kQuantType == QuantType::BQuantGrouped, "kQuantType must be BQuantGrouped");
|
||||
// Create Gemm tensor views, pad views and tile windows
|
||||
const auto& gemm_tensor_views_tuple =
|
||||
Base::template MakeGemmTensorViews<EpiloguePipeline::MemoryOperation>(
|
||||
a_ptr, b_ptr, aq_ptr, bq_ptr, c_ptr, kargs, splitk_batch_offset);
|
||||
|
||||
const auto& gemm_pad_views = Base::MakeGemmPadViews(gemm_tensor_views_tuple);
|
||||
auto gemm_tile_windows =
|
||||
Base::MakeGemmTileWindows(gemm_pad_views, block_idx_m, block_idx_n);
|
||||
|
||||
const index_t num_loop = __builtin_amdgcn_readfirstlane(
|
||||
TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
|
||||
const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
// Run GEMM cooperatively by whole workgroup.
|
||||
const auto& a_block_window = gemm_tile_windows.at(Base::I0);
|
||||
const auto& b_block_window = gemm_tile_windows.at(Base::I2);
|
||||
|
||||
const auto& bq_block_window = gemm_tile_windows.at(Base::I3);
|
||||
const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window,
|
||||
b_block_window,
|
||||
bq_block_window,
|
||||
num_loop,
|
||||
tail_num,
|
||||
smem_ptr_0,
|
||||
smem_ptr_1);
|
||||
|
||||
// Run Epilogue Pipeline
|
||||
auto& c_block_window = gemm_tile_windows.at(Base::I4);
|
||||
|
||||
EpiloguePipeline{}(c_block_window, c_block_tile, c_block_window, smem_ptr_0);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -458,6 +458,7 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
|
||||
return operator()<TailNum>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
@@ -467,5 +468,31 @@ struct WPQuantBPipelineAgBgCrV2 : public WeightPreshufflePipelineAGmemBGmemCRegV
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename BQDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem_ping,
|
||||
void* p_smem_pong) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto bool_val, auto tail_num_) {
|
||||
(void)bool_val; // Suppress unused parameter warning
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
return operator()<tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
[](const ADataType& a) { return a; },
|
||||
b_flat_dram_block_window_tmp,
|
||||
bq_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem_ping,
|
||||
p_smem_pong);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, true, tail_number);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -4,7 +4,14 @@ if(CK_USE_OCP_FP8)
|
||||
endif()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx94|gfx95")
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant test_grouped_gemm_quant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
# Split into three separate test executables for faster parallel compilation
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_rowcol test_grouped_gemm_quant_rowcol.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_rowcol PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_tensor test_grouped_gemm_quant_tensor.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_tensor PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_gtest_executable(test_ck_tile_grouped_gemm_quant_bquant test_grouped_gemm_quant_bquant.cpp)
|
||||
target_compile_options(test_ck_tile_grouped_gemm_quant_bquant PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
|
||||
@@ -22,26 +22,28 @@ using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantTyp
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant>,
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant>
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, BF8, F32, BF8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, BF8, F32, BF8, F32, F32, F16, BQuant, True>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using BQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::BQuantGrouped>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_BQuant = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, False>,
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, BQuant, True>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_BQuant, KernelTypes_BQuant);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_BQuant
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using RowColQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::RowColQuant>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_RowCol = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, RowColQuant, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_RowCol, KernelTypes_RowCol);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_RowCol
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_util_quant.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F32 = float;
|
||||
using FP8 = ck_tile::fp8_t;
|
||||
using BF8 = ck_tile::bf8_t;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using True = ck_tile::bool_constant<true>;
|
||||
using False = ck_tile::bool_constant<false>;
|
||||
using TensorQuant = std::integral_constant<ck_tile::QuantType, ck_tile::QuantType::TensorQuant>;
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes_Tensor = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, AQDataType, BDataType, BQDataType, AccDataType, CDataType, QuantType, PreshuffleB
|
||||
std::tuple< Row, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Col, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Row, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>,
|
||||
std::tuple< Col, Row, Row, FP8, F32, FP8, F32, F32, F16, TensorQuant, False>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileGroupedGemmQuant_Tensor, KernelTypes_Tensor);
|
||||
|
||||
#define TEST_CLASS_NAME TestCkTileGroupedGemmQuant_Tensor
|
||||
#include "test_grouped_gemm_quant_ut_cases.inc"
|
||||
#undef TEST_CLASS_NAME
|
||||
@@ -1,6 +1,6 @@
|
||||
#pragma once
|
||||
|
||||
TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
|
||||
TYPED_TEST(TEST_CLASS_NAME, Basic)
|
||||
{
|
||||
const int group_count = 8;
|
||||
std::vector<int> Ms;
|
||||
@@ -29,7 +29,7 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, Basic)
|
||||
|
||||
// No Hot Loop Test Case, this is to test the correctness of the kernel when there is no hot loop
|
||||
// Using 256x256x128 to match the test kernel's tile size (M_Tile=256, N_Tile=256, K_Tile=128)
|
||||
TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
|
||||
TYPED_TEST(TEST_CLASS_NAME, SmallUniform) //
|
||||
{
|
||||
const int group_count = 2;
|
||||
std::vector<int> Ms;
|
||||
@@ -55,3 +55,29 @@ TYPED_TEST(TestCkTileGroupedGemmQuant, SmallUniform) //
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
TYPED_TEST(TEST_CLASS_NAME, OddTail) //
|
||||
{
|
||||
const int group_count = 2;
|
||||
std::vector<int> Ms;
|
||||
std::vector<int> Ns;
|
||||
std::vector<int> Ks;
|
||||
std::vector<int> stride_As;
|
||||
std::vector<int> stride_Bs;
|
||||
std::vector<int> stride_Cs;
|
||||
std::vector<int> stride_AQs;
|
||||
std::vector<int> stride_BQs;
|
||||
for(int i = 0; i < group_count; i++)
|
||||
{
|
||||
Ms.push_back(256);
|
||||
Ns.push_back(256);
|
||||
Ks.push_back(128);
|
||||
|
||||
stride_As.push_back(0);
|
||||
stride_Bs.push_back(0);
|
||||
stride_Cs.push_back(0);
|
||||
stride_AQs.push_back(0);
|
||||
stride_BQs.push_back(0);
|
||||
}
|
||||
|
||||
this->Run(Ms, Ns, Ks, stride_As, stride_Bs, stride_Cs, stride_AQs, stride_BQs, group_count);
|
||||
}
|
||||
|
||||
@@ -17,23 +17,40 @@ template <typename Tuple>
|
||||
class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using AQDataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using BQDataType = std::tuple_element_t<6, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<7, Tuple>;
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using AQLayout = Row;
|
||||
using BQLayout = Col;
|
||||
static constexpr bool Persistent = true;
|
||||
using ALayout = std::tuple_element_t<0, Tuple>;
|
||||
using BLayout = std::tuple_element_t<1, Tuple>;
|
||||
using CLayout = std::tuple_element_t<2, Tuple>;
|
||||
using ADataType = std::tuple_element_t<3, Tuple>;
|
||||
using AQDataType = std::tuple_element_t<4, Tuple>;
|
||||
using BDataType = std::tuple_element_t<5, Tuple>;
|
||||
using BQDataType = std::tuple_element_t<6, Tuple>;
|
||||
using AccDataType = std::tuple_element_t<7, Tuple>;
|
||||
using CDataType = std::tuple_element_t<8, Tuple>;
|
||||
static constexpr auto QuantType = std::tuple_element_t<9, Tuple>::value;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using AQLayout = Row;
|
||||
using BQLayout = Col;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool PreshuffleB = std::tuple_element_t<10, Tuple>::value;
|
||||
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
static constexpr ck_tile::index_t get_k_from_preshuffled_warp_tile()
|
||||
{
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 64;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 128;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return sizeof(PrecType) == 2 ? 16 : 32;
|
||||
else
|
||||
return sizeof(PrecType) == 2 ? 32 : 64;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GroupedGemKernelParam_Mfma
|
||||
{
|
||||
@@ -52,7 +69,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
|
||||
static const ck_tile::index_t M_Warp_Tile = 32;
|
||||
static const ck_tile::index_t N_Warp_Tile = 32;
|
||||
static const ck_tile::index_t K_Warp_Tile = 16;
|
||||
static const ck_tile::index_t K_Warp_Tile =
|
||||
TestCkTileGroupedGemmQuant::template get_k_from_preshuffled_warp_tile<BDataType,
|
||||
M_Warp_Tile>();
|
||||
};
|
||||
|
||||
using grouped_gemm_kargs = ck_tile::QuantGroupedGemmHostArgs;
|
||||
@@ -66,8 +85,9 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer = false;
|
||||
constexpr bool TransposeC = false;
|
||||
constexpr bool DoubleSmemBuffer =
|
||||
PreshuffleB; // currently DoubleSmemBuffer is only supported for preshuffled B
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
@@ -90,7 +110,7 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
GroupedGemKernelParam::kPadN,
|
||||
GroupedGemKernelParam::kPadK,
|
||||
false,
|
||||
false,
|
||||
PreshuffleB,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
@@ -126,11 +146,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
BDataType,
|
||||
scheduler>>::type;
|
||||
|
||||
using GemmPipeline = typename std::conditional<
|
||||
QuantType == ck_tile::QuantType::BQuantGrouped,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>>::type;
|
||||
|
||||
using GemmPipeline = std::conditional_t<
|
||||
QuantType == ck_tile::QuantType::RowColQuant ||
|
||||
QuantType == ck_tile::QuantType::TensorQuant,
|
||||
ck_tile::GemmPipelineAgBgCrCompV3<QuantGemmProblem>,
|
||||
std::conditional_t<PreshuffleB == true,
|
||||
ck_tile::WPQuantBPipelineAgBgCrV2<QuantGemmProblem>,
|
||||
ck_tile::BQuantGemmPipelineAgBgCrCompV3<QuantGemmProblem>>>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
@@ -344,7 +366,18 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
bq_tensors[i].get_element_space_size_in_bytes()));
|
||||
|
||||
a_m_k_dev_buf[i]->ToDevice(a_m_k_tensors[i].data());
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
|
||||
if constexpr(PreshuffleB && QuantType == ck_tile::QuantType::BQuantGrouped)
|
||||
{
|
||||
auto b_shuffle_host =
|
||||
ck_tile::shuffle_b<GroupedGemKernelParam_Mfma>(b_k_n_tensors[i]);
|
||||
b_k_n_dev_buf[i]->ToDevice(b_shuffle_host.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
b_k_n_dev_buf[i]->ToDevice(b_k_n_tensors[i].data());
|
||||
}
|
||||
|
||||
aq_dev_buf[i]->ToDevice(aq_tensors[i].data());
|
||||
bq_dev_buf[i]->ToDevice(bq_tensors[i].data());
|
||||
c_m_n_dev_buf[i]->SetZero();
|
||||
@@ -485,3 +518,13 @@ class TestCkTileGroupedGemmQuant : public ::testing::Test
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
// Aliases for split test files
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_RowCol = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_Tensor = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
template <typename Tuple>
|
||||
using TestCkTileGroupedGemmQuant_BQuant = TestCkTileGroupedGemmQuant<Tuple>;
|
||||
|
||||
Reference in New Issue
Block a user