Merge commit 'a46b725992bdefad16d1c30dcfe4bb8441462907' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-27 19:11:23 +00:00
parent d3e72e87c4
commit 44a0e1afdb
5 changed files with 278 additions and 33 deletions

View File

@@ -182,6 +182,7 @@ struct GemmConfigPreshuffleDecode : public GemmConfigBase
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Default;
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_PRESHUFFLE_V2;
static constexpr bool Preshuffle = true;
static constexpr bool Persistent = true;
static constexpr bool DoubleSmemBuffer = true;
};

View File

@@ -167,6 +167,113 @@ float grouped_gemm(const std::vector<grouped_gemm_kargs>& gemm_descs,
return ave_time = BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
}
template <typename GemmConfig,
typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CDataType>
float grouped_gemm_tileloop(const ck_tile::stream_config& s,
const ck_tile::index_t num_groups,
void* kargs_ptr,
bool splitk)
{
using GemmShape = ck_tile::TileGemmShape<
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
ck_tile::
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>>;
using TilePartitioner =
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
GemmConfig::TileParitionerGroupNum,
GemmConfig::TileParitionerM01>;
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
GemmConfig::kPadN,
GemmConfig::kPadK,
GemmConfig::DoubleSmemBuffer,
ALayout,
BLayout,
CLayout,
GemmConfig::TransposeC,
GemmConfig::UseStructuredSparsity,
GemmConfig::Persistent,
GemmConfig::NumWaveGroups,
GemmConfig::Preshuffle>;
float ave_time{0};
const auto Run = [&](const auto memory_operation_) {
constexpr auto scheduler = GemmConfig::Scheduler;
constexpr auto memory_operation = memory_operation_.value;
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
scheduler>;
using GemmPipeline = typename PipelineTypeTraits<
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
using GemmEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
ADataType,
BDataType,
ck_tile::tuple<>, // DsDataType (empty for no D tensors)
AccDataType,
CDataType,
ck_tile::tuple<>, // DsLayout (empty for no D tensors)
CLayout,
ck_tile::element_wise::PassThrough,
TilePartitioner::MPerBlock,
TilePartitioner::NPerBlock,
GemmConfig::M_Warp,
GemmConfig::N_Warp,
GemmConfig::M_Warp_Tile,
GemmConfig::N_Warp_Tile,
GemmConfig::K_Warp_Tile,
UniversalGemmProblem::TransposeC,
memory_operation>>;
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
const dim3 blocks = Kernel::BlockSize();
const dim3 grids = Kernel::MaxOccupancyGridSize(s);
if(s.log_level_ > 0)
{
std::cout << "Launching kernel: " << Kernel::GetName() << " with args:" << " grid: {"
<< grids.x << ", " << grids.y << ", " << grids.z << "}" << ", blocks: {"
<< blocks.x << ", " << blocks.y << ", " << blocks.z << "}" << std::endl;
}
ave_time =
ck_tile::launch_kernel(s,
ck_tile::make_kernel<GemmConfig::kBlockPerCu>(
Kernel{},
grids,
blocks,
0,
ck_tile::cast_pointer_to_constant_address_space(kargs_ptr),
num_groups));
return ave_time;
};
if(splitk)
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::atomic_add>{});
}
else
{
Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
ck_tile::memory_operation_enum::set>{});
}
return ave_time;
}
#include "run_grouped_gemm_example.inc"
template <typename GemmConfig, typename PrecType>

View File

@@ -70,23 +70,13 @@ float invoke_gemm(int n_warmup,
}
else
{
if(GemmConfig::Preshuffle)
{
// not supported yet
throw std::runtime_error(
"Persistent grouped gemm with preshuffle is not supported yet");
}
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to haveCollapse
// commentComment on line L74tenpercent commented on Sep 5, 2025 tenpercenton Sep 5,
// 2025ContributorMore actionsdid you intend to remove the comment?Write a replyResolve
// commentCode has comments. Press enter to view. the gemm problems known on the host.
// Instead, we can just pass the pointer to the kernel and let the workgroups figure out
// which tiles to work on. This is useful when the gemm problems are generated dynamically.
// In this example however, we generate the `kargs` using the known gemm_descs,
// and copy the gemm descriptions to the device memory.
// The contents of the memory pointed to by `kargs_ptr` pointer could be
// written by e.g. another kernel from earlier stage.
// NOTE: With the persistent TileLoop kernel, we do not necessarily need to have the gemm
// problems known on the host. Instead, we can just pass the pointer to the kernel and let
// the workgroups figure out which tiles to work on. This is useful when the gemm problems
// are generated dynamically. In this example however, we generate the `kargs` using the
// known gemm_descs, and copy the gemm descriptions to the device memory. The contents of
// the memory pointed to by `kargs_ptr` pointer could be written by e.g. another kernel from
// earlier stage.
std::vector<ck_tile::GemmTransKernelArg<>> kargs;
void* kargs_ptr = gemm_workspace.GetDeviceBuffer();