mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 14:29:05 +00:00
Added Support for tile_grouped_gemm_preshuffle example (#2993)
* Added Support for tile_grouped_gemm_preshuffle example * Resolved PR comments + Added unit tests for preshuffle with persistent * Fixed CMake Build config error * Fix clang error that caused CI to fail * Fix clang formatting * Fix clang issue * Fix errors causing test cases to fail * Fix grouped_gemm_preshuffle unit test failure * Resolve PR comments * Cleaned code + removed unnecessary changes * Update test/ck_tile/grouped_gemm_preshuffle/test_grouped_gemm_preshuffle_util.hpp Co-authored-by: Aviral Goel <aviral.goel@amd.com> * Fix clang formatting * Made changes to improve code readability --------- Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com> Co-authored-by: Thomas Ning <Thomas.Ning@amd.com> Co-authored-by: Aviral Goel <aviral.goel@amd.com>
This commit is contained in:
@@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
|
||||
static constexpr bool Preshuffle = true;
|
||||
static constexpr bool Persistent = true;
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
};
|
||||
|
||||
|
||||
@@ -167,6 +167,113 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType>
|
||||
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
|
||||
const ck_tile::index_t num_groups,
|
||||
void* kargs_ptr,
|
||||
bool splitk)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
GemmConfig::Persistent,
|
||||
GemmConfig::NumWaveGroups,
|
||||
GemmConfig::Preshuffle>;
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
|
||||
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
|
||||
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
|
||||
}
|
||||
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
num_groups));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(splitk)
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
#include "run_grouped_gemm_example.inc"
|
||||
|
||||
template <typename GemmConfig, typename PrecType>
|
||||
|
||||
@@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup,
|
||||
}
|
||||
else
|
||||
{
|
||||
if(GemmConfig::Preshuffle)
|
||||
{
|
||||
// not supported yet
|
||||
throw std::runtime_error(
|
||||
"Persistent grouped gemm with preshuffle is not supported yet");
|
||||
}
|
||||
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have
|
||||
// the gemm problems known on the host.
|
||||
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
|
||||
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
|
||||
// In this example however, we generate the `kargs` using the known gemm_descs,
|
||||
// and copy the gemm descriptions to the device memory.
|
||||
// The contents of the memory pointed to by `kargs_ptr` pointer could be
|
||||
// written by e.g. another kernel from earlier stage.
|
||||
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm
|
||||
// problems known on the host. Instead, we can just pass the pointer to the kernel and let
|
||||
// the workgroups figure out which tiles to work on. This is useful when the gemm problems
|
||||
// are generated dynamically. In this example however, we generate the `kargs` using the
|
||||
// known gemm_descs, and copy the gemm descriptions to the device memory. The contents of
|
||||
// the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from
|
||||
// earlier stage.
|
||||
|
||||
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
|
||||
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();
|
||||
|
||||
@@ -8,11 +8,13 @@
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_grouped_gemm_preshuffle_util.hpp"
|
||||
|
||||
using F16 = ck_tile::half_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F32 = float;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using F16 = ck_tile::half_t;
|
||||
using F8 = ck_tile::fp8_t;
|
||||
using F32 = float;
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using False = std::false_type;
|
||||
using True = std::true_type;
|
||||
|
||||
// Custom tuple-like structure for kernel configuration
|
||||
template <typename ALayout_,
|
||||
@@ -22,6 +24,7 @@ template <typename ALayout_,
|
||||
typename BDataType_,
|
||||
typename AccDataType_,
|
||||
typename CDataType_,
|
||||
typename Persistent_,
|
||||
int M_Tile_val_,
|
||||
int N_Tile_val_,
|
||||
int K_Tile_val_,
|
||||
@@ -35,6 +38,7 @@ struct KernelConfig
|
||||
using BDataType = BDataType_;
|
||||
using AccDataType = AccDataType_;
|
||||
using CDataType = CDataType_;
|
||||
using Persistent = Persistent_;
|
||||
|
||||
static constexpr int M_Tile_ = M_Tile_val_;
|
||||
static constexpr int N_Tile_ = N_Tile_val_;
|
||||
@@ -44,11 +48,16 @@ struct KernelConfig
|
||||
|
||||
// clang-format off
|
||||
using KernelTypes = ::testing::Types<
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, M_Tile, N_Tile, K_Tile, BlockPerCu
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, 128, 128, 128, 2>
|
||||
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, Persistent ,M_Tile, N_Tile, K_Tile, BlockPerCu
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, False, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, False, 128, 128, 128, 2>,
|
||||
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 16, 64, 256, 1>,
|
||||
KernelConfig< Row, Col, Row, F16, F16, F32, F16, True, 128, 128, 128, 2>,
|
||||
KernelConfig< Row, Col, Row, F8, F8, F32, F16, True, 128, 128, 128, 2>
|
||||
>;
|
||||
// clang-format on
|
||||
|
||||
|
||||
@@ -39,9 +39,13 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
using BDataType = typename Tuple::BDataType;
|
||||
using AccDataType = typename Tuple::AccDataType;
|
||||
using CDataType = typename Tuple::CDataType;
|
||||
using PrecType = BDataType;
|
||||
using DsLayout = ck_tile::tuple<>; // not used
|
||||
using DsDataType = ck_tile::tuple<>; // not used
|
||||
|
||||
using DsLayout = ck_tile::tuple<>; // not used
|
||||
using DsDataType = ck_tile::tuple<>; // not used
|
||||
|
||||
// Get the persistent value from ck_tile::bool_constant
|
||||
using PersistentType = typename Tuple::Persistent;
|
||||
static constexpr bool Persistent = PersistentType::value;
|
||||
|
||||
static const bool kPadM = false;
|
||||
static const bool kPadN = false;
|
||||
@@ -231,6 +235,129 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename ALayout, typename BLayout, typename CLayout>
|
||||
void invoke_grouped_gemm_persistent(const std::vector<grouped_gemm_kargs>& gemm_descs,
|
||||
const ck_tile::stream_config& s,
|
||||
void* kargs_ptr)
|
||||
{
|
||||
using GemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
using TilePartitioner = ck_tile::
|
||||
GemmSpatiallyLocalTilePartitioner<GemmShape, TileParitionerGroupNum, TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
// Enable persistent mode for preshuffle
|
||||
using GemmUniversalTraits =
|
||||
ck_tile::TileGemmUniversalTraits</*kPadM*/ true,
|
||||
/*kPadN*/ true,
|
||||
/*kPadK*/ true,
|
||||
DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
TransposeC,
|
||||
/*UseStructuredSparsity*/ false,
|
||||
/*Persistent*/ true, // Enable persistent mode
|
||||
/*NumWaveGroups*/ 1,
|
||||
/*Preshuffle*/ true>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline =
|
||||
ck_tile::BaseWeightPreshufflePipelineAGmemBGmemCRegV2<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = gemm_descs[0].k_batch * K_Tile;
|
||||
const ck_tile::index_t K_split = (gemm_descs[0].K + k_grain - 1) / k_grain * K_Tile;
|
||||
const ck_tile::index_t num_loop =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
TileParitionerGroupNum,
|
||||
TileParitionerM01>::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run = [&](const auto has_hot_loop_,
|
||||
const auto tail_number_,
|
||||
const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
using UniversalGemmProblem =
|
||||
ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
ck_tile::GemmPipelineScheduler::Default,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
using GemmPipeline =
|
||||
ck_tile::WeightPreshufflePipelineAGmemBGmemCRegV2<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKargs(gemm_descs);
|
||||
EXPECT_TRUE(Kernel::IsSupportedArgument(kargs));
|
||||
const dim3 grids = Kernel::GridSize(gemm_descs);
|
||||
const dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
ck_tile::hip_check_error(hipMemcpyWithStream(kargs_ptr,
|
||||
kargs.data(),
|
||||
get_workspace_size(gemm_descs),
|
||||
hipMemcpyHostToDevice,
|
||||
s.stream_id_));
|
||||
|
||||
ave_time = ck_tile::launch_kernel(
|
||||
s,
|
||||
ck_tile::make_kernel<kBlockPerCu>(
|
||||
Kernel{},
|
||||
grids,
|
||||
blocks,
|
||||
0,
|
||||
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
|
||||
gemm_descs.size()));
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(gemm_descs[0].k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// EXPECT TO FAIL because splitk is not supported
|
||||
EXPECT_FALSE(true);
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
}
|
||||
|
||||
public:
|
||||
void Run(const std::vector<int>& Ms,
|
||||
const std::vector<int>& Ns,
|
||||
@@ -350,9 +477,20 @@ class TestCkTileGroupedGemmPreshuffle : public ::testing::Test
|
||||
ck_tile::DeviceMem gemm_workspace;
|
||||
gemm_workspace.Realloc(get_workspace_size(gemm_descs));
|
||||
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
invoke_grouped_gemm_persistent<ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
else
|
||||
{
|
||||
invoke_grouped_gemm<ALayout, BLayout, CLayout>(
|
||||
gemm_descs,
|
||||
ck_tile::stream_config{nullptr, false, 1},
|
||||
gemm_workspace.GetDeviceBuffer());
|
||||
}
|
||||
|
||||
// Copy results back to host for validation
|
||||
for(int i = 0; i < group_count; i++)
|
||||
|
||||
Reference in New Issue
Block a user