Remove code duplications in batched gemm wmma (#3580)

* Moved device struct for batched gemm wmma to a common file

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

* Use the common device struct in the scaled batched gemm wmma implementation

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

* Boy-scout: Remove unused includes and ambiguous comment

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

* Moved pointer offset calculation and gridwise argument to common struct

This change enables further code reduction by re-using the common structs for the batched gemm and batched gemm b scale wmma implementations.

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

* Moved type string to the common struct of DeviceBatchedGemm_Wmma_CShuffleV3_Common

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>

---------

Signed-off-by: Chris Tsiaousis <chris.tsiaousis@streamhpc.com>
This commit is contained in:
chris-tsiaousis-hpc
2026-01-23 21:39:03 +01:00
committed by GitHub
parent 67f0b74ec6
commit e1c46ff548
4 changed files with 719 additions and 977 deletions

View File

@@ -77,6 +77,122 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
}
// Batched GEMM kernel for WMMA targets (CShuffle V3 pipeline).
//
// Grid layout (established by the launch code elsewhere, inferred here from the
// index usage below): blockIdx.x tiles the GEMM problem, blockIdx.y carries the
// batch index (see the long comment inside on why Y rather than Z), and
// blockIdx.z carries the split-K batch (consumed via SplitKBatchOffset).
//
// Template parameters:
//   GridwiseGemm                   - gridwise GEMM providing Argument,
//                                    SplitKBatchOffset, the epilogue types and Run().
//   ComputePtrOffsetOfStridedBatch - functor returning per-batch pointer offsets
//                                    for the A/B/C grids (and the B-scale grid
//                                    when IsBScaled is true).
//   HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum
//                                  - forwarded unchanged to GridwiseGemm::Run.
//   MinimumOccupancy               - min blocks/CU hint for __launch_bounds__.
//   IsBScaled                      - selects the B-scaled Run overload and adds
//                                    the B-scale batch/split-K pointer offsets.
//
// Only compiled for gfx11/gfx12; on other targets the body reduces to
// discarding the arguments.
template <typename GridwiseGemm,
          typename ComputePtrOffsetOfStridedBatch,
          bool HasMainKBlockLoop,
          InMemoryDataOperationEnum CGlobalMemoryDataOperation,
          index_t MinimumOccupancy = 1,
          bool IsBScaled = false,
          TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
    kernel_batched_gemm_wmma_cshuffle_v3(
        typename GridwiseGemm::Argument karg, // This works for now but it actually receives a
                                              // DeviceBatchedGemm_Wmma_CShuffleV3::Argument
                                              // argument through implicit conversion to base class!
        const ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
    // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions, so the
    // whole body is compiled out for the AtomicAdd + half/bhalf combination.
    using c_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
    if constexpr(!(CGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
                   (std::is_same_v<c_data_type, ck::half_t> ||
                    std::is_same_v<c_data_type, ck::bhalf_t>)))
    {
#endif
        // The normal approach to batching would be to increase the grid size by just stretching out
        // the grid Z dimension (which is the outermost dimension), but this depends on lower level
        // functions not directly using the Z dimension for other calculations. As it turns out, k
        // batching does rely directly on blockIdx.Z through SplitKBatchOffset. Therefore, for now
        // we will use the grid Y dimension for batching. This may be a bit fragile.
        const index_t g_idx = amd_wave_read_first_lane(blockIdx.y);
        // Per-batch byte offsets for A/B/C; amd_wave_read_first_lane makes the
        // values wave-uniform so they live in scalar registers.
        const long_index_t a_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx));
        const long_index_t b_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx));
        const long_index_t c_batch_offset =
            amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx));
        // Pick the epilogue implementation: direct store when the gridwise GEMM
        // declares both B-wave transfer applicability and direct-store support,
        // otherwise fall back to the C-shuffle epilogue.
        using EpilogueType =
            typename std::conditional<GridwiseGemm::IsBWaveTransferApplicable &&
                                          GridwiseGemm::UseDirectStore,
                                      typename GridwiseGemm::EpilogueDirectStore,
                                      typename GridwiseGemm::EpilogueCShuffle>::type;
        // LDS requirement depends on the chosen epilogue; allocated statically.
        constexpr index_t LDS_size =
            GridwiseGemm::template GetSharedMemoryNumberOfByte<EpilogueType>();
        __shared__ char p_shared[LDS_size];
        // Split-K offsets derived from blockIdx.z (see batching comment above).
        auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
        // shift A matrices pointer for splitk (and by the per-batch offset)
        typename GridwiseGemm::AsGridPointer p_as_grid_shift;
        static_for<0, GridwiseGemm::NumATensor, 1>{}([&](auto i) {
            using ADataType_ =
                remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::AsDataType_>>;
            p_as_grid_shift(i) = static_cast<const ADataType_*>(karg.p_as_grid[i]) +
                                 splitk_batch_offset.a_k_split_offset[i] + a_batch_offset;
        });
        // shift B matrices pointer for splitk (and by the per-batch offset)
        typename GridwiseGemm::BsGridPointer p_bs_grid_shift;
        static_for<0, GridwiseGemm::NumBTensor, 1>{}([&](auto i) {
            using BDataType_ =
                remove_cvref_t<tuple_element_t<i.value, typename GridwiseGemm::BsDataType_>>;
            p_bs_grid_shift(i) = static_cast<const BDataType_*>(karg.p_bs_grid[i]) +
                                 splitk_batch_offset.b_k_split_offset[i] + b_batch_offset;
        });
        auto epilogue_args = EpilogueType{};
        if constexpr(IsBScaled)
        {
            // B-scaled variant: the B-scale pointer also needs both a per-batch
            // and a split-K offset before being handed to Run.
            const long_index_t b_scale_batch_offset =
                amd_wave_read_first_lane(compute_ptr_offset_of_batch.GetScaleBPtrOffset(g_idx));
            GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
                p_as_grid_shift,
                p_bs_grid_shift,
                karg.p_ds_grid,
                karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
                karg.p_a_scale_grid,
                karg.p_b_scale_grid + b_scale_batch_offset +
                    splitk_batch_offset.scale_b_k_split_offset,
                p_shared,
                karg,
                karg.a_element_op,
                karg.b_element_op,
                karg.cde_element_op,
                epilogue_args);
        }
        else
        {
            // Unscaled variant: identical call without the scale pointers.
            GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
                p_as_grid_shift,
                p_bs_grid_shift,
                karg.p_ds_grid,
                karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
                p_shared,
                karg,
                karg.a_element_op,
                karg.b_element_op,
                karg.cde_element_op,
                epilogue_args);
        }
#if defined(__gfx11__)
    }
#endif
#else
    // Non-WMMA targets: compile to a stub and silence unused-parameter warnings.
    ignore = karg;
    ignore = compute_ptr_offset_of_batch;
#endif
}
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,