From b1d34eae288c5f39ba20498a57626d0bce7b8021 Mon Sep 17 00:00:00 2001 From: SamiAario-AMD Date: Fri, 19 Sep 2025 07:26:10 +0300 Subject: [PATCH] Add gemm weight preshuffle pk_int_t support (#2858) * Factor out the three separate copies of load_interleaved_pk_type into a common utility class * Add preprocessing with optional cache flushing and clearing of output for k_batch > 1 to the weight preshuffle GEMM example * Remove a duplicate function * Add support for B tensor type pk_int4_t for the weight preshuffle GEMM, with tests included * I4 support introduced more failing test cases that mirror the existing ones for F8 * Simplify the check for which tests to skip (they all have F8 as A tensor type) * Add a changelog entry * add the test for v2 wp pipeline, polish the code, add the support of int4 for v2 wp pipeline * have a workable version for atomic add * Revert "have a workable version for atomic add" This reverts commit 792377a590c26cfff9c8f545d9a9e8484a7422eb. --------- Co-authored-by: ThomasNing [ROCm/composable_kernel commit: 47cd0d5cff77658adc1c9f184c012ec3496e8214] --- CHANGELOG.md | 1 + .../ops/common/load_interleaved_pk_type.hpp | 58 +++++++++++++++++++ .../block/block_universal_gemm_as_bs_cr.hpp | 37 ++++-------- ..._pipeline_agmem_bgmem_creg_base_policy.hpp | 18 +++--- .../wp_pipeline_agmem_bgmem_creg_v1.hpp | 28 +++++---- .../wp_pipeline_agmem_bgmem_creg_v2.hpp | 28 +++++---- .../block_universal_gemm_as_aquant_bs_cr.hpp | 30 +++------- .../block_universal_gemm_as_bs_bquant_cr.hpp | 31 +++------- .../test_batched_gemm_ut_cases.inc | 3 +- .../test_gemm_pipeline_smoke_run_test.inc | 57 +----------------- .../test_gemm_pipeline_kernel_types.hpp | 25 ++++---- .../test_gemm_pipeline_ut_cases.inc | 8 +-- .../test_gemm_pipeline_util.hpp | 36 +++++++++--- 13 files changed, 183 insertions(+), 177 deletions(-) create mode 100644 include/ck_tile/ops/common/load_interleaved_pk_type.hpp diff --git a/CHANGELOG.md b/CHANGELOG.md index dafe1b5c87..6dd06195c9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj ## Composable Kernel 1.2.0 for ROCm 7.0.0 ### Added +* Added support for B Tensor type pk_int4_t in the CK TILE weight preshuffle GEMM. * Added support for B Tensor Preshuffle in CK TILE Grouped GEMM. * Added a basic copy kernel example and supporting documentation for new CK Tile developers. * Added support for bf16, f32, and f16 for 2D and 3D NGCHW grouped convolution backward data diff --git a/include/ck_tile/ops/common/load_interleaved_pk_type.hpp b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp new file mode 100644 index 0000000000..f8432b9da0 --- /dev/null +++ b/include/ck_tile/ops/common/load_interleaved_pk_type.hpp @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/config.hpp" +#include "ck_tile/ops/elementwise.hpp" + +namespace ck_tile { + +template +struct is_pk_int4 : std::false_type +{ +}; +template <> +struct is_pk_int4 : std::true_type +{ +}; + +template +struct InterleavedPKTypeLoader +{ + template + CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, + const WarpWindow& warp_window) + { + const element_wise::PassThroughPack8 elementwise_op{}; + + static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); + constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; + const auto in_dstr_tensors = load_tile(warp_window); + + using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); + static_for<0, thread_buffer_size, 1>{}([&](auto i) { + elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), + in_dstr_tensors.get_thread_buffer().template get_as()[i]); + }); + } +}; + +template +CK_TILE_DEVICE void load_int4_tile(WarpTile& dst, const WarpWindow& src) +{ + if constexpr(is_pk_int4>::value) + { + InterleavedPKTypeLoader::load_interleaved_pk_type(dst, src); + } + else + { + dst = load_tile(src); + } +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp index e1b0792ecf..94adb42880 100644 --- a/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp +++ b/include/ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" @@ -13,7 +14,9 @@ namespace ck_tile { // A is block window on shared memory // B is block window on shared memory // C is block distributed tensor -template +template struct BlockUniversalGemmAsBsCr { private: @@ -91,6 +94,7 @@ struct BlockUniversalGemmAsBsCr using ComputeDataType = remove_cvref_t; using CDataType = remove_cvref_t; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -179,25 +183,6 @@ struct BlockUniversalGemmAsBsCr return b_block_dstr_encode; } - private: - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - constexpr index_t UnaryOpSize = 8; - const element_wise::PassThroughPack8 elementwise_op{}; - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } - template struct BlockGemmImpl { @@ -239,7 +224,7 @@ struct BlockUniversalGemmAsBsCr if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -247,7 +232,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { @@ -317,7 +302,7 @@ struct BlockUniversalGemmAsBsCr { if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else if constexpr(ALoadTranspose) { @@ -329,7 +314,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else if constexpr(BLoadTranspose) { @@ -468,7 +453,7 @@ struct BlockUniversalGemmAsBsCr if constexpr(std::is_same_v) { - load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else if constexpr(ALoadTranspose) { @@ -480,7 +465,7 @@ struct BlockUniversalGemmAsBsCr } if constexpr(std::is_same_v) { - load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else if constexpr(BLoadTranspose) { diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp index 71ca907c07..f1c8f2ec9b 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp @@ -289,13 +289,17 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy { using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using WarpTile = typename Problem::BlockGemmShape::WarpTile; - using WarpGemm = WarpGemmDispatcher; + using BTypeToUse = + std::conditional_t, + typename Problem::ADataType, + typename Problem::BDataType>; + using WarpGemm = WarpGemmDispatcher; using BlockWeightPreshufflePolicy = BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy::value && !is_detected::value, - bool>* = nullptr> + bool>* = nullptr, + index_t UnaryOpSize_ = 8> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -310,14 +312,14 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 NIterPerWarp> b_flat_dram_windows; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + + statically_indexed_array, NIterPerWarp> b_warp_tensor; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_2; static_for<0, NIterPerWarp, 1>{}([&](auto nIter) { @@ -327,7 +329,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -375,7 +378,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -408,7 +412,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -445,7 +450,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV1 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_2(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_2(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp index 356ad91448..670f4b0575 100644 --- a/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_v2.hpp @@ -4,6 +4,7 @@ #pragma once #include "ck_tile/core.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/host/concat.hpp" #include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp" @@ -514,7 +515,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 typename AElementFunction, typename std::enable_if_t::value && !is_detected::value, - bool>* = nullptr> + bool>* = nullptr, + index_t UnaryOpSize_ = 8> CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp, const AElementFunction& a_element_func, const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp, @@ -631,19 +633,19 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 b_flat_distribution); // pingpong buffer for B + using BTypeToUse = + std::conditional_t, ADataType, BDataType>; + using BTileType = decltype(make_static_distributed_tensor(b_flat_distribution)); + statically_indexed_array< statically_indexed_array, NIterPerWarp> b_flat_dram_windows; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_ping; - statically_indexed_array< - statically_indexed_array, - NIterPerWarp> + statically_indexed_array, NIterPerWarp> b_warp_tensor_pong; // Prefetch A0 @@ -659,7 +661,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); // move B window to next flat K @@ -706,7 +709,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -782,7 +786,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_ping(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_ping(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); @@ -862,7 +867,8 @@ struct WeightPreshufflePipelineAGmemBGmemCRegV2 move_tile_window(b_flat_dram_windows(nIter)(kIter), {nIter * NFlatPerBlockPerIter, kIter * KFlatPerBlockPerIter}); - b_warp_tensor_pong(nIter)(kIter) = load_tile(b_flat_dram_windows(nIter)(kIter)); + load_int4_tile( + b_warp_tensor_pong(nIter)(kIter), b_flat_dram_windows(nIter)(kIter)); }); }); diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp index 182d9251b1..f75d02f1a6 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp @@ -5,19 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" namespace ck_tile { -template +template struct BlockGemmAQuantBase { using AQDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - static constexpr index_t UnaryOpSize = UnaryOpSize_; template CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) { @@ -42,23 +42,6 @@ struct BlockGemmAQuantBase } return scale_reg_f; } - - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } }; // A is block window on shared memory @@ -66,7 +49,9 @@ struct BlockGemmAQuantBase // Consecutive kQuantGroupSize elements of A are quantized with a separate scale. // B is block window on shared memory // C is block distributed tensor -template +template struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { private: @@ -172,6 +157,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase using Base = BlockGemmAQuantBase; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -292,7 +278,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -302,7 +288,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { diff --git a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp index 7e28ea8fa9..077d0d8fe2 100644 --- a/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp +++ b/include/ck_tile/ops/gemm_group_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp @@ -5,19 +5,19 @@ #include "ck_tile/core.hpp" #include "ck_tile/core/arch/arch.hpp" +#include "ck_tile/ops/common/load_interleaved_pk_type.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/elementwise.hpp" namespace ck_tile { -template +template struct BlockGemmBQuantBase { using BQDataType = remove_cvref_t; using ComputeDataType = remove_cvref_t; - static constexpr index_t UnaryOpSize = UnaryOpSize_; template CK_TILE_DEVICE static float cvt_scale_to_fp32(T scale) { @@ -42,24 +42,6 @@ struct BlockGemmBQuantBase } return scale_reg_f; } - - // can be inherited from A - template - CK_TILE_DEVICE static void load_interleaved_pk_type(WarpTile& warp_tile, - const WarpWindow& warp_window) - { - const element_wise::PassThroughPack8 elementwise_op{}; - - static_assert(WarpTile::get_thread_buffer_size() % UnaryOpSize == 0); - constexpr index_t thread_buffer_size = WarpTile::get_thread_buffer_size() / UnaryOpSize; - const auto in_dstr_tensors = load_tile(warp_window); - - using ComputeVectorType = ComputeDataType __attribute__((ext_vector_type(UnaryOpSize))); - static_for<0, thread_buffer_size, 1>{}([&](auto i) { - elementwise_op(warp_tile.get_thread_buffer().template get_as()(i), - in_dstr_tensors.get_thread_buffer().template get_as()[i]); - }); - } }; // A is block window on shared memory @@ -67,7 +49,9 @@ struct BlockGemmBQuantBase // Consecutive kQuantGroupSize elements of B are quantized with a separate scale. // B is block window on shared memory // C is block distributed tensor -template +template struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { private: @@ -170,6 +154,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase using Base = BlockGemmBQuantBase; + using Loader = remove_cvref_t>; using WarpGemm = remove_cvref_t; static constexpr index_t KIterPerWarp = Traits::KIterPerWarp; @@ -291,7 +276,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(a_warp_tile_, a_block_window); + Loader::load_interleaved_pk_type(a_warp_tile_, a_block_window); } else { @@ -301,7 +286,7 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase { static_assert(std::is_same_v || std::is_same_v); - Base::load_interleaved_pk_type(b_warp_tile_, b_block_window); + Loader::load_interleaved_pk_type(b_warp_tile_, b_block_window); } else { diff --git a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc index 035377734b..8f24c9bfe1 100644 --- a/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc +++ b/test/ck_tile/batched_gemm/test_batched_gemm_ut_cases.inc @@ -29,7 +29,8 @@ TYPED_TEST(TestCkTileBatchedGemm, Basic) {256, 256, 64, 8}, {256, 256, 64, 16}}; - if(ck_tile::get_device_name() != "gfx950") { + if(ck_tile::get_device_name() != "gfx950") + { gemmParams.emplace_back(256, 256, 128, 2); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc index ab74e4e7b1..57feefceab 100644 --- a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -2,6 +2,8 @@ // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once +#include "ck_tile/host/permute_pk_int4.hpp" + template static constexpr inline auto is_row_major(Layout layout_) { @@ -91,61 +93,6 @@ void permute_tensor_b(Tensor& tensor) } } -template -void permute_vectors_i4x4_b(Tensor& tensor) -{ - const ck_tile::index_t K = tensor.get_length(0); - const ck_tile::index_t N = tensor.get_length(1); - // vector pk_i4x4 permute - for(int i = 0; i < N; i++) - { - for(int j = 0; j < K; j += 8) - { - int8_t input[8]; - - for(int k = 0; k < 4; k++) - { - int8_t i4x2 = tensor(j + k * 2, i).data; - input[k * 2 + 0] = (i4x2 >> 4) & 0xf; - input[k * 2 + 1] = (i4x2 >> 0) & 0xf; - } - - // permute 01234567->20643175 - { - int8_t hi = input[2]; - int8_t lo = input[0]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 0, i) = i4x2; - } - - { - int8_t hi = input[6]; - int8_t lo = input[4]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 2, i) = i4x2; - } - - { - int8_t hi = input[3]; - int8_t lo = input[1]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 4, i) = i4x2; - } - - { - int8_t hi = input[7]; - int8_t lo = input[5]; - int8_t i4x2 = (hi << 4) | lo; - - tensor(j + 6, i) = i4x2; - } - } - } -} - template ; -using WeightPreshuffle = - ck_tile::integral_constant; - -// Adding alias for the F8 parameters to facilitate skipping tests. -// This alias can be removed once test failures are fixed. -using F8Types = std::tuple; +using WeightPreshuffleV1 = + ck_tile::integral_constant; +using WeightPreshuffleV2 = + ck_tile::integral_constant; // clang-format off using KernelTypesWeightPreshuffle = ::testing::Types< - std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffle>, - std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffle> -#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 - , F8Types + std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV1>, + std::tuple< Row, Col, Row, F16, F16, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, BF16, BF16, F32, BF16, Default, WeightPreshuffleV1> +#if !CK_TILE_USE_WMMA || CK_TILE_USE_OCP_FP8 + , + std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV1>, + std::tuple< Row, Col, Row, F8, F8, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV2>, + std::tuple< Row, Col, Row, F8, I4, F32, F16, Default, WeightPreshuffleV1> #endif >; diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc index 389e0d53ea..bb56c63413 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_ut_cases.inc @@ -20,7 +20,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -48,7 +48,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -77,7 +77,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_128x2048x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } @@ -106,7 +106,7 @@ TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x128x4096) TYPED_TEST(TEST_SUITE_NAME, GemmPreshuffle_1024x2048x128) { - if constexpr(std::is_same_v) + if constexpr(std::is_same_v, F8>) { GTEST_SKIP() << "Skipping this test due to failures with F8"; } diff --git a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp index 42d0149498..62f819ac1e 100644 --- a/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp +++ b/test/ck_tile/gemm_weight_preshuffle/test_gemm_pipeline_util.hpp @@ -8,6 +8,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" #include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/host/permute_pk_int4.hpp" #include "ck_tile/ops/epilogue.hpp" #include "ck_tile/ops/gemm.hpp" @@ -34,20 +35,31 @@ auto calculate_rtol_atol(const ck_tile::index_t K, enum struct GemmPipelineType { - WeightPreshuffle + WeightPreshuffleV1, + WeightPreshuffleV2 }; template struct GemmPipelineTypeSelector; template -struct GemmPipelineTypeSelector +struct GemmPipelineTypeSelector { using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV1; using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV1; - static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffle"; } + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV1"; } }; + +template +struct GemmPipelineTypeSelector +{ + using base_pipeline = ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2; + using pipeline = ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2; + + static constexpr auto GetName() { return "GemmPipelineAgBgCrWeightPreshuffleV2"; } +}; + template struct config { @@ -122,7 +134,8 @@ class TestCkTileGemmPipeline : public ::testing::Test constexpr bool kPadK = PadK; constexpr bool preshuffle = Preshuffle; - constexpr bool DoubleSmemBuffer = false; + constexpr bool DoubleSmemBuffer = + (PipelineType == GemmPipelineType::WeightPreshuffleV2) ? true : false; // TODO: For now - but this should also be a test parameter constexpr bool TransposeC = false; @@ -391,10 +404,19 @@ class TestCkTileGemmPipeline : public ::testing::Test ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); - ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); - a_m_k_dev_buf.ToDevice(a_m_k.data()); - b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + ck_tile::HostTensor b_shuffle_host = shuffle_b(b_k_n); + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_shuffle_host_dev = b_shuffle_host; + ck_tile::permute_vectors_i4x4_b(b_shuffle_host_dev); + b_k_n_dev_buf.ToDevice(b_shuffle_host_dev.data()); + } + else + { + b_k_n_dev_buf.ToDevice(b_shuffle_host.data()); + } c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero();