Added Support for tile_grouped_gemm_preshuffle example (#2993)

* Added Support for tile_grouped_gemm_preshuffle example

* Resolved PR comments + Added unit tests for preshuffle with persistent

* Fixed CMake Build config error

* Fix clang error that caused CI to fail

* Fix clang formatting

* Fix clang issue

* Fix errors causing test cases to fail

* Fix grouped_gemm_preshuffle unit test failure

* Resolve PR comments

* Cleaned code + removed unnecassary changes

* Update test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp

Co-authored-by: Aviral Goel <aviral.goel@amd.com>

* Fix clang formatting

* Made changes to improve code readability

---------

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
This commit is contained in:
mkumar16-amd
2025-10-28 00:01:19 +05:30
committed by GitHub
parent 6c2ca1211a
commit a46b725992
5 changed files with 278 additions and 33 deletions

View File

@@ -8,11 +8,13 @@
#include "ck_tile/host.hpp"
#include "test_grouped_gemm_preshuffle_util.hpp"
using F16 = ck_tile::half_t;
using F8 = ck_tile::fp8_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using F16 = ck_tile::half_t;
using F8 = ck_tile::fp8_t;
using F32 = float;
using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using False = std::false_type;
using True = std::true_type;
// Custom tuple-like structure for kernel configuration
template <typename ALayout_,
@@ -22,6 +24,7 @@ template <typename ALayout_,
typename BDataType_,
typename AccDataType_,
typename CDataType_,
typename Persistent_,
int M_Tile_val_,
int N_Tile_val_,
int K_Tile_val_,
@@ -35,6 +38,7 @@ struct KernelConfig
using BDataType = BDataType_;
using AccDataType = AccDataType_;
using CDataType = CDataType_;
using Persistent = Persistent_;
static constexpr int M_Tile_ = M_Tile_val_;
static constexpr int N_Tile_ = N_Tile_val_;
@@ -44,11 +48,16 @@ struct KernelConfig
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2>
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>
>;
// clang-format on

View File

@@ -39,9 +39,13 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
using BDataType = typename Tuple::BDataType;
using AccDataType = typename Tuple::AccDataType;
using CDataType = typename Tuple::CDataType;
using PrecType = BDataType;
using DsLayout = ck_tile::tuple<>; // not used
using DsDataType = ck_tile::tuple<>; // not used
using DsLayout = ck_tile::tuple<>; // not used
using DsDataType = ck_tile::tuple<>; // not used
// Get the persistent value from ck_tile::bool_constant
using PersistentType = typename Tuple::Persistent;
static constexpr bool Persistent = PersistentType::value;
static const bool kPadM = false;
static const bool kPadN = false;
@@ -231,6 +235,129 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
private:
template <typename ALayout, typename BLayout, typename CLayout>
void invoke_grouped_gemm_persistent(const std::vector<grouped_gemm_kargs>& gemm_descs,
const ck_tile::stream_config& s,
void* kargs_ptr)
{
using GemmShape =
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
using TilePartitioner = ck_tile::
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
// Enable persistent mode for preshuffle
using GemmUniversalTraits =
ck_tile::TileGemmUniversalTraits</*kPadM*/ true,
/*kPadN*/ true,
/*kPadK*/ true,
DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
TransposeC,
/*UseStructuredSparsity*/ false,
/*Persistent*/ true, // Enable persistent mode
/*NumWaveGroups*/ 1,
/*Preshuffle*/ true>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
const ck_tile::index_t num_loop =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
TileParitionerGroupNum,
TileParitionerM01>::GetLoopNum(K_split);
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
float ave_time{0};
const auto Run = [&](const auto has_hot_loop_,
const auto tail_number_,
const auto memory_operation_) {
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem =
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
ck_tile::GemmPipelineScheduler::Default,
has_hot_loop_v,
tail_number_v>;
using GemmPipeline =
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<
ck_tile::CShuffleEpilogueProblem<ADataType,
BDataType,
DsDataType,
AccDataType,
CDataType,
DsLayout,
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
M_Warp,
N_Warp,
M_Warp_Tile,
N_Warp_Tile,
K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKargs(gemm_descs);
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
const dim3 grids = Kernel::GridSize(gemm_descs);
const dim3 blocks = Kernel::BlockSize();
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
kargs.data(),
get_workspace_size(gemm_descs),
hipMemcpyHostToDevice,
s.stream_id_));
ave_time = ck_tile::launch_kernel(
s,
ck_tile::make_kernel<kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
gemm_descs.size()));
return ave_time;
};
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
if(gemm_descs[0].k_batch == 1)
{
Run(has_hot_loop_,
tail_number_,
ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
else
{
// EXPECT TO FAIL because splitk is not supported
EXPECT_FALSE(true);
}
};
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
public:
void Run(const std::vector<int>& Ms,
const std::vector<int>& Ns,
@@ -350,9 +477,20 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
ck_tile::DeviceMem gemm_workspace;
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
invoke_grouped_gemm<ALayout, BLayout, CLayout>(gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
if constexpr(Persistent)
{
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
}
else
{
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
gemm_descs,
ck_tile::stream_config{nullptr, false, 1},
gemm_workspace.GetDeviceBuffer());
}
// Copy results back to host for validation
for(int i = 0; i < group_count; i++)