Support b_scale (#2350)

- Extend pipelines v1 and v3
- Add instances
- Add tests
- Add example

Co-authored-by: Illia Silin <98187287+illsilin@users.noreply.github.com>
Author: Enrico Degregori
Date: 2025-07-25 03:49:58 +02:00
Committed by: GitHub
Parent: 2addf05b91
Commit: b01a27ff22
24 changed files with 3744 additions and 1660 deletions
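
For orientation, here is a minimal host-side reference sketch (plain C++, not part of this commit) of what a blockwise B-scale GEMM computes. The row-major layouts, the function name, and the packed scale-grid layout are illustrative assumptions; the one-scale-per-(ScaleBlockN x ScaleBlockK)-tile granularity mirrors the ScaleBlockN/ScaleBlockK template parameters introduced below.

// Reference-only sketch, not CK device code: C = A * (B dequantized by one
// scale per ScaleBlockN x ScaleBlockK tile). Layouts and names are assumptions
// for illustration.
#include <vector>

void gemm_b_scale_reference(const std::vector<float>& A,      // M x K, row-major
                            const std::vector<float>& B,      // K x N, row-major quantized values
                            const std::vector<float>& BScale, // ceil(N/ScaleBlockN) x ceil(K/ScaleBlockK)
                            std::vector<float>& C,            // M x N, row-major
                            int M, int N, int K, int ScaleBlockN, int ScaleBlockK)
{
    const int scale_stride = (K + ScaleBlockK - 1) / ScaleBlockK; // scale entries per N-block row
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
        {
            float acc = 0.f;
            for(int k = 0; k < K; ++k)
            {
                // One scale value covers a ScaleBlockN x ScaleBlockK tile of B.
                const float s = BScale[(n / ScaleBlockN) * scale_stride + k / ScaleBlockK];
                acc += A[m * K + k] * (B[k * N + n] * s);
            }
            C[m * N + n] = acc;
        }
}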


@@ -91,6 +91,78 @@ struct BlockwiseGemmWmmaops_pipeline_base
true>
c_thread_buf_;
struct Empty
{
__device__ Empty(){};
template <index_t NBuffer>
__device__ void GlobalLoad(bool cond)
{
ignore = NBuffer;
ignore = cond;
}
};
template <index_t ScaleSliceSizeN,
index_t ScaleSliceSizeK,
index_t NWaves,
index_t ScaleBlockK,
index_t NumberOfBuffers,
typename GridDesc,
typename ThreadCopy,
typename GridBuffer,
typename ThreadStaticBuffer,
typename BScaleThreadDesc>
struct BScale
{
__device__ BScale(GridDesc b_scale_grid_desc_,
ThreadCopy b_scale_thread_copy_,
GridBuffer b_scale_grid_buf_)
: b_scale_thread_copy(b_scale_thread_copy_),
b_scale_grid_desc(b_scale_grid_desc_),
b_scale_grid_buf(b_scale_grid_buf_){};
static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
static constexpr auto b_scale_thread_copy_step =
make_tuple(make_multi_index(NWaves * NPerWmma, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
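// The three slice-window steps above are: (0) advance N by one wave tile while the
// NRepeat scales of the current buffer are loaded, (1) rewind N and keep the K
// offset (the next main-loop iteration reuses the same scales), and (2) rewind N
// and advance the K offset by the scale entries consumed per KPerBlock.
// GlobalLoad() picks (2) when `cond` is true, otherwise (1).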
template <index_t NBuffer>
__device__ void GlobalLoad(bool cond)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, Number<0>{}),
b_scale_thread_bufs(Number<NBuffer>{}));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
});
if(cond)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
}
}
ThreadCopy b_scale_thread_copy;
GridDesc b_scale_grid_desc;
GridBuffer b_scale_grid_buf;
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> b_scale_thread_bufs;
};
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
__device__ static auto GetWaveIdx()
@@ -285,7 +357,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
@@ -296,7 +368,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,


@@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
using typename Base::Empty;
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
@@ -158,7 +160,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename BScaleStruct>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -172,7 +175,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
// BScaleThreadCopy
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
@@ -186,6 +192,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
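// Prefetch the B scales for the first K block; when a scale block spans only a
// single main-loop iteration (num_loop_per_scale == 1), the scale window already
// advances to the next scale block along K.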
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
@@ -195,20 +203,42 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
auto blockwise_gemm_func = [&]() {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
});
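// With no B scale (BScaleStruct == Empty) the B tile is copied from LDS
// unchanged; otherwise the per-block scale element for this (n0, k0) is passed
// to the thread copy so it can be applied to the loaded B values.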
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
}
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -258,6 +288,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
blockwise_gemm_func();
block_sync_lds();
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
@@ -378,6 +409,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
using typename Base::Empty;
static constexpr index_t NumKClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KRepeatPerCluster = math::max(KRepeat / NumKClusters, 1);
@@ -407,7 +440,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename BScaleStruct>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -421,7 +455,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
// BScaleThreadCopy
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
@@ -435,6 +472,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
@@ -445,30 +484,57 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
auto blockwise_gemm_func = [&]() {
static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
I0,
I0,
I0,
I0,
I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0_inner, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
I0,
I0,
I0,
I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0_inner, I0, I0, I0),
b_thread_buf);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
m0,
I0,
I0,
I0,
I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0_inner, I0, I0, I0),
a_thread_buf);
});
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
n0,
I0,
I0,
I0,
I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0_inner, I0, I0, I0),
b_thread_buf);
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
n0,
I0,
I0,
I0,
I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(I0)[Number<
n0 * BScaleStruct::num_scale_k_block +
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0_inner, I0, I0, I0),
b_thread_buf);
});
}
});
__builtin_amdgcn_sched_barrier(0);
@@ -564,6 +630,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
block_sync_lds();
blockwise_gemm_func();
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
@@ -613,7 +680,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, MRepeat, 1, 1, 1, A_K1>,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
A_K1,
@@ -624,7 +691,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, NRepeat, 1, 1, 1, B_K1>,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
B_K1,


@@ -132,6 +132,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
using Base::a_block_desc_k0_m0_m1_m2_k1;
using Base::b_block_desc_k0_n0_n1_n2_k1;
using typename Base::Empty;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
@@ -255,6 +257,58 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
*/
}
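// LocalLoad factors out the per-KRepeat LDS-to-VGPR prefetch of the A and B
// thread tiles; it takes the same Empty/BScale branch as the v1 pipelines so
// the scaled and unscaled paths share one hot-loop body.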
template <typename ABlockBuffer,
typename AThreadBuffer,
typename BBlockBuffer,
typename BThreadBuffer,
typename BScaleStruct>
__device__ inline void LocalLoad(ABlockBuffer& a_block_buf,
AThreadBuffer& a_thread_buf,
BBlockBuffer& b_block_buf,
BThreadBuffer& b_thread_buf,
BScaleStruct& b_scale_struct) const
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
});
if constexpr(ck::is_same_v<BScaleStruct, Empty>)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
});
}
});
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
@@ -269,7 +323,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer>
typename CThreadBuffer,
typename BScaleStruct>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -283,7 +338,10 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
// BScaleThreadCopy
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
@@ -298,6 +356,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
@@ -314,20 +374,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// Local prefetch 1
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
b_thread_buf);
});
LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
__builtin_amdgcn_sched_barrier(0);
@@ -348,6 +396,8 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
@@ -392,22 +442,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, I0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
a_thread_buf);
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, I0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, k0, I0, I0, I0),
b_thread_buf);
});
LocalLoad(a_block_buf, a_thread_buf, b_block_buf, b_thread_buf, b_scale_struct);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);


@@ -16,6 +16,7 @@
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
@@ -229,222 +230,28 @@ struct DeviceGemm_Wmma_CShuffleV3 : public DeviceGemmV2<ALayout,
using Argument = typename GridwiseGemm::Argument;
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for the preparation and invocation of the actual
/// GPU kernel function. It usually determines the launched grid size, prepares kernel
/// arguments, and performs specific kernel configuration selection based on
/// runtime arguments.
///
/// @note If appropriately configured it may measure kernel execution time.
///
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
{
// TODO: Implement
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return false;
}
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
std::is_same_v<CDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic


@@ -0,0 +1,302 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename BScaleDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockN, // scale block for N
index_t ScaleBlockK, // scale block for K
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
BScaleDataType,
CDataType,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon = DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
ADataType,
BDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
index_t GetKPerBlock() override { return KPerBlock; }
bool GetPermuteB() override { return PermuteB; }
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideScaleB,
const BScaleDataType* p_b_scale,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
return Argument{p_a,
p_b,
p_c,
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideScaleB,
p_b_scale,
KBatch,
a_element_op,
b_element_op,
c_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
const void* p_b,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
index_t StrideC,
index_t StrideScaleB,
const void* p_b_scale,
index_t KBatch,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
static_cast<CDataType*>(p_c),
M,
N,
K,
StrideA,
StrideB,
StrideC,
StrideScaleB,
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemm_Wmma_CShuffleV3_BScale"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma<<"x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat<<"x" << NRepeat<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -0,0 +1,265 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename GridwiseGemm,
typename ADataType,
typename BDataType,
typename CDataType,
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t BlockSize,
index_t AK1,
index_t BK1,
GemmSpecialization GemmSpec,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB>
struct DeviceGemm_Wmma_CShuffleV3_Common
{
using Argument = typename GridwiseGemm::Argument;
/// @brief Helper structure responsible for kernel invocation.
///
/// @paragraph The `Invoker` class is responsible for the preparation and invocation of the actual
/// GPU kernel function. It usually determines the launched grid size, prepares kernel
/// arguments, and performs specific kernel configuration selection based on
/// runtime arguments.
///
/// @note If appropriately configured it may measure kernel execution time.
///
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0);
auto size_a_buffer = a_grid_desc_ak0_m_ak1.GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
auto size_b_buffer = b_grid_desc_bk0_n_bk1.GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
ck::utility::RotatingMemWrapper<Argument> rotating_mem(
arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_c_grid,
0,
arg_.M * arg_.N * sizeof(CDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_c_grid,
0,
arg.M * arg.N * sizeof(CDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
{
// TODO: Implement
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
static constexpr bool IsValidCompilationParameter()
{
// TODO: properly implement this check
return true;
}
static bool IsSupportedArgument(const Argument& arg)
{
if(!ck::is_gfx11_supported() && !ck::is_gfx12_supported())
{
return false;
}
if constexpr(std::is_same_v<CDataType, ck::half_t> ||
std::is_same_v<CDataType, ck::bhalf_t>)
{
if(arg.KBatch > 1 && ck::is_gfx11_supported())
{
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
return false;
}
}
if constexpr(std::is_same_v<ComputeTypeA, f8_t> || std::is_same_v<ComputeTypeA, bf8_t> ||
std::is_same_v<ComputeTypeB, f8_t> || std::is_same_v<ComputeTypeB, bf8_t>)
{
if(ck::is_gfx11_supported())
{
return false;
}
}
if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding ||
GemmSpec == GemmSpecialization::NKPadding ||
GemmSpec == GemmSpecialization::MNKPadding ||
GemmSpec == GemmSpecialization::KPadding))
{
return false;
}
return GridwiseGemm::CheckValidity(arg);
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -0,0 +1,551 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/env.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmma_selector.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp"
#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
template <typename ALayout,
typename BLayout,
typename CLayout,
typename ADataType,
typename BDataType,
typename AccDataType,
typename CShuffleDataType,
typename CDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockN, // scale N
index_t ScaleBlockK, // scale K
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1Value,
index_t BK1Value,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool AThreadTransferSrcResetCoordinateAfterRun,
index_t ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BThreadTransferSrcResetCoordinateAfterRun,
index_t BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemm_wmma_cshuffle_v3_b_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1Value,
BK1Value,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>
{
using BScaleType = ck::half_t;
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
CLayout,
ADataType,
BDataType,
AccDataType,
CShuffleDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
MPerBlock,
NPerBlock,
KPerBlock,
AK1Value,
BK1Value,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
AThreadTransferSrcResetCoordinateAfterRun,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
BThreadTransferSrcResetCoordinateAfterRun,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::I3;
using Base::I4;
using Base::I5;
using Base::I6;
using Base::I7;
using Base::AK0Number;
using Base::AK1Number;
using Base::BK0Number;
using Base::BK1Number;
using Base::APackedSize;
using Base::BPackedSize;
using Base::CalculateAK0Padded;
using Base::CalculateBK0Padded;
using Base::CalculateKPadded;
using Base::CalculateKRead;
using Base::CalculateMBlock;
using Base::CalculateMPadded;
using Base::CalculateNBlock;
using Base::CalculateNPadded;
using Base::MakeAGridDescriptor_AK0_M_AK1;
using Base::MakeBGridDescriptor_BK0_N_BK1;
using Base::MakeCGridDescriptor_M_N;
using Base::GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat;
using Base::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
using Base::GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1;
using Base::GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1;
struct Problem
{
__host__ Problem(index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
N{N_},
K{K_},
StrideA{StrideA_},
StrideB{StrideB_},
StrideC{StrideC_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
KRead{CalculateKRead(K_, KBatch_)},
KPadded{CalculateKPadded(K_, KBatch_)},
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
{
}
__host__ void Print() const
{
std::cout << "problem {"
<< "M:" << M << ", "
<< "N:" << N << ", "
<< "K:" << K << ", "
<< "SA:" << StrideA << ", "
<< "SB:" << StrideB << ", "
<< "SC:" << StrideC << ", "
<< "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", "
<< "KRead:" << KRead << ", "
<< "KP:" << KPadded << ", "
<< "AK0:" << AK0 << ", "
<< "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", "
<< "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
index_t N;
index_t K;
index_t StrideA;
index_t StrideB;
index_t StrideC;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
index_t NPadded;
index_t KRead;
index_t KPadded;
index_t AK0;
index_t BK0;
index_t MBlock;
index_t NBlock;
};
// Argument
struct Argument : public tensor_operation::device::BaseArgument, public Problem
{
__host__ Argument(const ADataType* p_a_grid_,
const BDataType* p_b_grid_,
CDataType* p_c_grid_,
index_t M_,
index_t N_,
index_t K_,
index_t StrideA_,
index_t StrideB_,
index_t StrideC_,
index_t StrideScaleB_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_,
bool is_reduce_ = false)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideC_, StrideScaleB_, k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_c_grid{p_c_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
c_element_op{c_element_op_},
is_reduce(is_reduce_)
{
}
__host__ __device__ inline bool IsReduceAdd() const
{
return (Problem::KBatch > 1) && is_reduce;
}
__host__ __device__ inline bool IsAtomicAdd() const
{
return (Problem::KBatch > 1) && (!is_reduce);
}
const ADataType* p_a_grid;
const BDataType* p_b_grid;
CDataType* p_c_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
const CElementwiseOperation c_element_op;
bool is_reduce;
};
struct SplitKBatchOffset
{
__device__ SplitKBatchOffset(Argument& karg)
{
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead / APackedSize;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
a_k_split_offset = blockIdx.z * karg.KRead * karg.StrideA;
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
b_k_split_offset = blockIdx.z * karg.KRead * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
if constexpr(!PermuteB)
{
b_k_split_offset = blockIdx.z * karg.KRead / BPackedSize;
}
else
{
const int k0_offset = karg.KRead * karg.N;
b_k_split_offset = blockIdx.z * k0_offset / BPackedSize;
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
}
else
{
karg.K = karg.K - karg.KRead * (karg.KBatch - 1);
}
if(karg.IsReduceAdd())
{
c_reduce_offset = blockIdx.z * karg.M * karg.N;
}
else
{
c_reduce_offset = 0;
}
}
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t scale_k_split_offset; // New member for scale matrix offset
index_t c_reduce_offset;
};
using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
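// MakeBScale builds the per-thread B-scale machinery for one block's N tile:
// a global buffer over the (ceil(N / ScaleBlockN), ceil(K / ScaleBlockK))
// scale tensor, a threadwise copy positioned at this block's N offset, and the
// BScale helper consumed by the blockwise GEMM pipeline.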
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK, typename BScaleType>
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const BScaleType* p_b_scale_grid,
index_t block_n_id)
{
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
static constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma;
static constexpr auto ScaleSliceSizeN = NRepeat;
static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma;
auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
b_thread_offset_k / ScaleBlockK));
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_scale_thread_desc.GetElementSpaceSize());
using BScale =
typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
ScaleSliceSizeK,
NWaves,
ScaleBlockK,
NumberOfBuffers,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
}
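// Number of KPerBlock main-loop iterations covered by one scale block along K,
// i.e. how often the pipeline has to fetch a new set of B scales.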
__device__ static index_t GetKBlockPerScale()
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
}
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void Run(const ADataType* p_a_grid,
const BDataType* p_b_grid,
CDataType* p_c_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem)
{
const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideA, problem.AK0);
const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0);
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
c_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
const auto block_work_idx =
block_2_ctile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));
if(!block_2_ctile_map.ValidCTileIndex(
block_work_idx,
make_tuple(c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
c_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
{
return;
}
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// BScale struct
auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
const index_t num_k_block_per_scale = GetKBlockPerScale();
Base::template Run<decltype(a_grid_desc_ak0_m_ak1),
decltype(b_grid_desc_bk0_n_bk1),
decltype(c_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(b_scale_struct),
HasMainKBlockLoop,
CGlobalMemoryDataOperation,
TailNum>(p_a_grid,
p_b_grid,
p_c_grid,
p_shared,
a_grid_desc_ak0_m_ak1,
b_grid_desc_bk0_n_bk1,
c_grid_desc_mblock_mperblock_nblock_nperblock,
block_m_id,
block_n_id,
num_k_block_per_scale,
b_scale_struct);
}
// NOTE: Wrapper function to have __global__ function in common
// between gemm_universal, b_scale, ab_scale, etc.
template <bool HasMainKBlockLoop,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
TailNumber TailNum = TailNumber::Odd>
__device__ static void
Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, const Argument& karg)
{
Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_c_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_shared,
karg);
}
};
} // namespace ck

File diff suppressed because it is too large.