Merge branch 'develop' into ginolu/add_wgmfma_dispatcher

Gino Lu
2025-09-08 19:10:23 -05:00
147 changed files with 10722 additions and 1484 deletions

View File

@@ -33,15 +33,14 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
     static constexpr index_t kBlockSize = Problem::kBlockSize;
 
+    static constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
+    using WG = remove_cvref_t<decltype(config.template at<0>())>;
+
     CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
     {
         constexpr index_t MPerBlock = BlockGemmShape::kM;
         constexpr index_t NPerBlock = BlockGemmShape::kN;
 
-        constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
-        using WG = remove_cvref_t<decltype(config.template at<0>())>;
         constexpr index_t MWarp = config.template at<1>();
         constexpr index_t NWarp = config.template at<2>();
@@ -74,9 +73,6 @@ struct BlockWeightPreshuffleASmemBSmemCRegV1
         constexpr index_t MPerBlock = BlockGemmShape::kM;
         constexpr index_t KPerBlock = BlockGemmShape::kK;
 
-        constexpr auto config = BlockPolicy::template GetWarpGemmMWarpNWarp<Problem>();
-        using WG = remove_cvref_t<decltype(config.template at<0>())>;
         constexpr index_t MWarp = config.template at<1>();
         constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
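
Both hunks in this file apply the same refactor: config and WG move from per-function locals to struct scope, so MakeCBlockTile and the other members share one definition instead of each recomputing it. A minimal sketch of the pattern, with an illustrative stand-in Policy type (the names below are not from this commit):

// hoist_sketch.cpp -- shared constexpr state at struct scope
struct Policy
{
    static constexpr int GetConfig() { return 4; }
};

struct Before
{
    // Before: each member recomputes the policy-derived constant locally.
    static constexpr int TileA() { constexpr int config = Policy::GetConfig(); return config * 2; }
    static constexpr int TileB() { constexpr int config = Policy::GetConfig(); return config * 3; }
};

struct After
{
    // After: computed once at struct scope and reused by every member.
    static constexpr int config = Policy::GetConfig();
    static constexpr int TileA() { return config * 2; }
    static constexpr int TileB() { return config * 3; }
};

static_assert(Before::TileA() == After::TileA() && Before::TileB() == After::TileB());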

View File

@@ -266,6 +266,10 @@ struct GroupedGemmKernel
                               const tuple<index_t, index_t>& block_idx_2d,
                               const index_t block_idx_z) const
     {
+        static_assert(GemmPipeline::DoubleSmemBuffer || !GemmPipeline::Preshuffle,
+                      "SingleSmemBuffer and Preshuffle cannot both be enabled simultaneously!");
+
         const auto [iM, iN] = block_idx_2d;
         const index_t i_m = __builtin_amdgcn_readfirstlane(iM * TilePartitioner::MPerBlock);
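
The new static_assert encodes a compile-time precondition: Preshuffle implies DoubleSmemBuffer, i.e. the preshuffled pipeline may not run on a single LDS buffer. A hedged sketch of the same implication guard, with a hypothetical Pipe in place of the real GemmPipeline:

// implication_guard.cpp -- "A implies B" as a static_assert
template <bool DoubleSmemBuffer, bool Preshuffle>
struct Pipe
{
    // DoubleSmemBuffer || !Preshuffle  ==  (Preshuffle implies DoubleSmemBuffer)
    static_assert(DoubleSmemBuffer || !Preshuffle,
                  "Preshuffle requires the double-LDS-buffer pipeline");
};

inline Pipe<true, true> ok_preshuffle{};  // two LDS buffers: allowed
inline Pipe<false, false> ok_single{};    // plain single-buffer pipeline: allowed
// inline Pipe<false, true> rejected{};   // would fail the static_assert above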
@@ -282,11 +286,15 @@ struct GroupedGemmKernel
         // allocate LDS
         __shared__ char smem_ptr_0[GetSmemSize()];
+        // TODO: Can we simplify this branching logic?
         if constexpr(GemmPipeline::DoubleSmemBuffer == true)
         {
             __shared__ char smem_ptr_1[GetSmemSize()];
-            if constexpr(UsePersistentKernel)
+            if constexpr(UsePersistentKernel || GemmPipeline::Preshuffle)
             {
                 RunGemmWithPipelineSelection2LDS(a_ptr,
                                                  b_ptr,
                                                  c_ptr,
@@ -296,9 +304,11 @@ struct GroupedGemmKernel
                                                  splitk_batch_offset,
                                                  i_m,
                                                  i_n);
+                return;
             }
             else
             {
                 Base::RunGemm2LDS({a_ptr},
                                   {b_ptr},
                                   {/*ds_ptr*/},
@@ -311,14 +321,14 @@ struct GroupedGemmKernel
                                   i_n);
             }
         }
-        else
+        else // SingleSmemBuffer
         {
             if constexpr(UsePersistentKernel)
             {
                 RunGemmWithPipelineSelection(
                     a_ptr, b_ptr, c_ptr, smem_ptr_0, kargs, splitk_batch_offset, i_m, i_n);
             }
-            else
+            else // Non-persistent kernel
             {
                 Base::RunGemm({a_ptr},
                               {b_ptr},
@@ -438,17 +448,34 @@ struct GroupedGemmKernel
         // Get hot-loop and tail configuration
         const index_t num_loop = __builtin_amdgcn_readfirstlane(
             TilePartitioner::GetLoopNum(splitk_batch_offset.splitted_k));
-        const bool has_hot_loop   = GemmPipeline::BlockHasHotloop(num_loop);
         const TailNumber tail_num = GemmPipeline::GetBlockLoopTailNum(num_loop);
-        // Run GEMM pipeline
-        const auto& c_block_tile = GemmPipeline{}.template operator()(a_block_window[Base::I0],
-                                                                      b_block_window[Base::I0],
-                                                                      num_loop,
-                                                                      has_hot_loop,
-                                                                      tail_num,
-                                                                      smem_ptr_0,
-                                                                      smem_ptr_1);
+        // Run GEMM pipeline with compile-time branching
+        const auto& c_block_tile = [&]() {
+            if constexpr(GemmPipeline::Preshuffle)
+            {
+                // Preshuffle version - without has_hot_loop parameter
+                return GemmPipeline{}.template operator()(a_block_window[Base::I0],
+                                                          b_block_window[Base::I0],
+                                                          num_loop,
+                                                          tail_num,
+                                                          smem_ptr_0,
+                                                          smem_ptr_1);
+            }
+            else
+            {
+                // Regular version - with has_hot_loop parameter
+                const bool has_hot_loop = GemmPipeline::BlockHasHotloop(num_loop);
+                return GemmPipeline{}.template operator()(a_block_window[Base::I0],
+                                                          b_block_window[Base::I0],
+                                                          num_loop,
+                                                          has_hot_loop,
+                                                          tail_num,
+                                                          smem_ptr_0,
+                                                          smem_ptr_1);
+            }
+        }();
+
         // Run Epilogue Pipeline
         auto& c_block_window = gemm_tile_windows.at(Base::I3);
         EpiloguePipeline{}.template
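
The rewritten call site wraps pipeline invocation in an immediately-invoked lambda so that if constexpr can pick between two different call signatures while still initializing a single c_block_tile. A self-contained sketch of the idiom; the two run overloads below are stand-ins, not the real GemmPipeline interface:

// iife_dispatch.cpp -- compile-time signature selection via an IIFE
#include <cstdio>

struct PreshufflePipe
{
    static constexpr bool Preshuffle = true;
    static int run(int num_loop, int tail) { return num_loop * 10 + tail; }
};

struct RegularPipe
{
    static constexpr bool Preshuffle = false;
    static int run(int num_loop, bool has_hot_loop, int tail) { return num_loop * 10 + has_hot_loop + tail; }
};

template <typename Pipeline>
int dispatch(int num_loop, int tail)
{
    // Only the taken branch is instantiated, so each pipeline has to provide
    // just the signature its own branch calls.
    const auto result = [&] {
        if constexpr(Pipeline::Preshuffle)
            return Pipeline::run(num_loop, tail);
        else
            return Pipeline::run(num_loop, /*has_hot_loop=*/num_loop > 1, tail);
    }();
    return result;
}

int main() { std::printf("%d %d\n", dispatch<PreshufflePipe>(3, 1), dispatch<RegularPipe>(3, 1)); }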
@@ -491,8 +518,9 @@ struct GroupedGemmKernel
         const auto gemm_desc_ptr = reinterpret_cast<const GemmTransKernelArg*>(
             cast_pointer_to_generic_address_space(gemm_descs_const));
         const index_t group_id = FindGroupId(gemm_desc_ptr, block_id, group_count);
         const auto& kargs      = gemm_desc_ptr[group_id];
+        const auto grid_size_2d = TilePartitioner::GridSize(kargs.group_karg.M, kargs.group_karg.N);
+
         const auto block_idx_2d = OffsetTile1DPartitioner::GetOffsetedTileIndex(
             0,
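
For context, FindGroupId maps a flat block id to the GEMM group whose tile range owns it; only then can OffsetTile1DPartitioner recover the 2-D tile index within that group. A hedged sketch of such a lookup, assuming each group argument carries a precomputed [block_start, block_end) range (the real GemmTransKernelArg layout may differ):

// find_group_sketch.cpp -- flat block id -> group id
#include <cstdint>

struct GroupRange
{
    std::int32_t block_start; // first flat block id owned by this group
    std::int32_t block_end;   // one past the last owned block id
};

inline std::int32_t FindGroupIdSketch(const GroupRange* groups,
                                      std::int32_t block_id,
                                      std::int32_t group_count)
{
    // Linear scan over the (typically small) group table.
    for(std::int32_t g = 0; g < group_count; ++g)
        if(block_id >= groups[g].block_start && block_id < groups[g].block_end)
            return g;
    return group_count - 1; // out-of-range fallback; real code may assert instead
}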

View File

@@ -43,7 +43,7 @@ template <bool kPadM_,
           bool UseStructuredSparsity_ = false,
           bool UsePersistentKernel_   = false,
           index_t NumWaveGroups_      = 1,
-          bool Preshuffle_            = 0>
+          bool Preshuffle_            = false>
 struct TileGemmUniversalTraits
 {
     static constexpr bool kPadM = kPadM_;
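
For a bool non-type template parameter, 0 and false are equivalent, so this hunk only makes the default read consistently with UseStructuredSparsity_ and UsePersistentKernel_ above it. A trivial sketch with an abbreviated trait list (the real TileGemmUniversalTraits takes many more parameters):

// traits_default_sketch.cpp -- defaulted bool NTTP
template <bool kPadM_, bool UsePersistentKernel_ = false, bool Preshuffle_ = false>
struct TraitsSketch
{
    static constexpr bool kPadM      = kPadM_;
    static constexpr bool Preshuffle = Preshuffle_;
};

static_assert(!TraitsSketch<true>::Preshuffle);             // stays off by default
static_assert(TraitsSketch<true, false, true>::Preshuffle); // opted in explicitly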

View File

@@ -296,6 +296,73 @@ struct UniversalWeightPreshufflePipelineAgBgCrPolicy
                                               WarpGemm>;
         return BlockWeightPreshuffleASmemBSmemCRegV1<Problem, BlockWeightPreshufflePolicy>{};
     }
+
+    /**
+     * @brief Get the vector store size for C tensor.
+     *
+     * @tparam Problem - Gemm pipeline problem class.
+     *
+     * @note The vector store size for the output C tensor depends on multiple factors,
+     *       such as its data layout and warp-gemm C transposition. In general it is
+     *       the number of consecutive elements in the contiguous C dimension held by
+     *       a single thread.
+     *
+     * @return The vector store size for C tensor.
+     */
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
+    {
+        using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
+        using WG_       = typename BlockGemm::WG;
+
+        constexpr bool TransposeC = Problem::TransposeC;
+        using CLayout   = typename Problem::CLayout;
+        using CWarpDstr = typename WG_::CWarpDstr;
+
+        // N is the contiguous dimension
+        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
+        {
+            if constexpr(TransposeC)
+            {
+                // Each thread holds multiple consecutive elements in the N dimension;
+                // however, consecutive threads' elements are strided.
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths =
+                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+            else
+            {
+                // Each thread holds just a single item in the N dimension.
+                return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
+            }
+        }
+        // M is the contiguous dimension
+        else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
+        {
+            if constexpr(TransposeC)
+            {
+                // Each thread holds just a single item in the M dimension.
+                return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
+            }
+            else
+            {
+                // Each thread holds multiple consecutive elements in the M dimension;
+                // however, consecutive threads' elements are strided.
+                constexpr index_t NDimY = CWarpDstr::NDimY;
+                constexpr auto c_warp_y_lengths =
+                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
+                static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
+                              c_warp_y_lengths.get(number<NDimY - 1>{}));
+                return c_warp_y_lengths.get(number<NDimY - 1>{});
+            }
+        }
+        else
+        {
+            static_assert(false, "Unsupported CLayout!");
+        }
+    }
 };
 } // namespace ck_tile
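
The doc comment on GetVectorSizeC states the underlying rule: the C-store vector width equals the number of consecutive elements one thread owns along the contiguous C dimension. A scalar sketch of that rule for a row-major C tile; the warp-tile numbers below are made up for illustration and do not come from WG:

// vector_size_sketch.cpp -- consecutive-elements-per-thread rule
#include <cstdio>

struct WarpTile
{
    static constexpr int kN          = 32; // warp tile width along contiguous N
    static constexpr int kCNLane     = 32; // lanes laid out along N
    static constexpr int kCM1PerLane = 4;  // consecutive N elements per lane when C is transposed
};

template <bool TransposeC>
constexpr int VectorSizeCRowMajor()
{
    if constexpr(TransposeC)
        return WarpTile::kCM1PerLane;            // each lane holds a short contiguous run
    else
        return WarpTile::kCNLane / WarpTile::kN; // lanes interleave along N: one element each
}

int main() { std::printf("%d %d\n", VectorSizeCRowMajor<true>(), VectorSizeCRowMajor<false>()); }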