Wmma support for gemm_ab_scale (#3314)

* Support gemm_ab_scale:

 - Add tests
 - Integrate scaling implementation in multiple D
 - Generalize existing b_scale for ab_scale
 - Add instances
 - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK
 - Add support for all layouts supported by xdl
 - Fix splitk xdl

* Fix copyright

* Wmma support for gemm_blockscale_wp (#3315)

* Support for  preshuffle with ab scale

 - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale
 - add support for AScaleLayout amnd BScaleLayout (can be different
   from ALayout and BLayout, respectively)
 - add Run method in v1 pipeline to support preshuffle + scaling
 - add support for preshuffle gemms in common invoker
 - Add splitk support

* Fix copyright header
This commit is contained in:
Enrico Degregori
2025-12-11 09:06:20 +01:00
committed by GitHub
parent d66e5f667c
commit ce99cab605
51 changed files with 5144 additions and 552 deletions

View File

@@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base
}
};
template <index_t ScaleSliceSizeN,
template <index_t ScaleSliceSizeMN,
index_t ScaleSliceStrideMN,
index_t ScaleSliceSizeK,
index_t NWaves,
index_t ScaleBlockK,
index_t NumberOfBuffers,
index_t RegSizePerWmma,
typename GridDesc,
typename ThreadCopy,
typename GridBuffer,
typename ThreadStaticBuffer,
typename BScaleThreadDesc>
struct BScale
typename ThreadDesc>
struct ABScale
{
__device__ BScale(GridDesc b_scale_grid_desc_,
ThreadCopy b_scale_thread_copy_,
GridBuffer b_scale_grid_buf_)
: b_scale_thread_copy(b_scale_thread_copy_),
b_scale_grid_desc(b_scale_grid_desc_),
b_scale_grid_buf(b_scale_grid_buf_) {};
__device__ ABScale(GridDesc scale_grid_desc_,
ThreadCopy scale_thread_copy_,
GridBuffer scale_grid_buf_)
: scale_thread_copy(scale_thread_copy_),
scale_grid_desc(scale_grid_desc_),
scale_grid_buf(scale_grid_buf_) {};
static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{});
static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
static constexpr index_t num_slice_mn = ScaleSliceSizeMN;
static constexpr index_t num_slice_k = ScaleSliceSizeK;
static constexpr index_t reg_size_per_wmma = RegSizePerWmma;
static constexpr auto b_scale_thread_copy_step =
make_tuple(make_multi_index(NWaves * NPerWmma, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
static constexpr auto scale_thread_desc = ThreadDesc{};
static constexpr auto scale_thread_copy_step =
make_tuple(make_multi_index(ScaleSliceStrideMN, 0),
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0),
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN,
ScaleSliceSizeK));
template <index_t NBuffer>
__device__ void GlobalLoad(bool cond)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, Number<0>{}),
b_scale_thread_bufs(Number<NBuffer>{}));
static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) {
scale_thread_copy.Run(scale_grid_desc,
scale_grid_buf,
scale_thread_desc,
make_tuple(m0 * Number<RegSizePerWmma>{}, Number<0>{}),
scale_thread_bufs(Number<NBuffer>{}));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<0>{}));
});
if(cond)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<1>{}));
}
}
ThreadCopy b_scale_thread_copy;
GridDesc b_scale_grid_desc;
GridBuffer b_scale_grid_buf;
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> b_scale_thread_bufs;
ThreadCopy scale_thread_copy;
GridDesc scale_grid_desc;
GridBuffer scale_grid_buf;
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> scale_thread_bufs;
};
template <typename AScaleStruct, typename BScaleStruct>
struct CScale
{
__device__ CScale() {}
static constexpr auto reg_size_per_wmma =
ck::is_same_v<BScaleStruct, Empty> && ck::is_same_v<AScaleStruct, Empty>
? 1
: wmma_gemm.GetRegSizePerWmma();
static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{},
Number<AScaleStruct::num_slice_mn>{},
Number<BScaleStruct::num_slice_mn>{}));
using CScaleThreadDesc = decltype(c_scale_thread_desc);
static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
using ThreadStaticBuffer = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize()));
__device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct)
{
using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc);
using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc);
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_bufs(I0)(Number<c_offset>{}) =
a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
});
});
});
}
__device__ void Clear()
{
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
}
template <index_t k_index, index_t m_index, index_t n_index, typename CThreadBuf>
__device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf)
{
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple(
k_index,
(m_index * num_scale_m_block / MRepeat) % num_scale_m_block +
(Number<t / (reg_size_per_wmma / AScaleStruct::reg_size_per_wmma)>{}) %
AScaleStruct::reg_size_per_wmma,
(n_index * num_scale_n_block / NRepeat) % num_scale_n_block));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_bufs(I0)[Number<cscale_offset>{}]);
});
}
StaticallyIndexedArray<ThreadStaticBuffer, Number<1>{}> c_scale_thread_bufs;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, AccDataType, 1, reg_size_per_wmma, true>
c_thread_buf_per_scale;
};
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

View File

@@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -207,6 +209,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
// Local load
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
@@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
Base::a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
Base::b_thread_desc_.GetElementSpaceSize());
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Initialize C
c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
// Local load
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
Base::a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
Base::b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
};
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
block_sync_lds();
blockwise_gemm_func();
block_sync_lds();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
blockwise_gemm_func();
}
}
protected:
// A[MRepeat, I1, I1, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
return TailNumber::Full;
}
template <typename AScaleStruct, typename BScaleStruct>
struct KLoopParams
{
static constexpr auto KRepeatNoScale = 1;
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
};
template <>
struct KLoopParams<Empty, Empty>
{
static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
static constexpr index_t NumScaleKBlock = 1;
static constexpr index_t KRepeatPerNumScaleKBlock = 1;
};
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
@@ -543,7 +747,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -576,6 +782,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(I0)[Number<
b_scale_struct.scale_thread_bufs(I0)[Number<
n0 * BScaleStruct::num_scale_k_block +
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
@@ -982,7 +1190,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer&,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct&,
index_t num_loop,
index_t) const
@@ -1319,6 +1529,248 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc&,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer&,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
auto gemm_core_func = [&](auto reg_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k_index,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
};
auto a_local_prefetch_func = [&]() {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
});
};
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
__builtin_amdgcn_sched_barrier(0);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1
block_sync_lds();
a_local_prefetch_func();
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
a_scale_struct.template GlobalLoad<0>(
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
gemm_core_func(wmma_reg_buf);
block_sync_lds();
// loop prefetch copy
a_local_prefetch_func();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
gemm_core_func(I0);
block_sync_lds();
// tail Local Prefetch A1
a_local_prefetch_func();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
__builtin_amdgcn_sched_barrier(0);
gemm_core_func(I1);
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == TailNumber::Odd)
{
gemm_core_func(I0);
}
}
protected:
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<KPack / B_K1 / B_KRow>{},

View File

@@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
KInner,
TransposeC>;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::I3;
using Base::A_K1;
using Base::A_KRow;
@@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -383,6 +388,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
// Initialize C
c_thread_buf.Clear();
// Local prefetch 1
block_sync_lds();
auto local_load_func = [&]() {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
});
};
local_load_func();
__builtin_amdgcn_sched_barrier(0);
// Main body, perform when at least 3 loops exist.
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
block_sync_lds();
local_load_func();
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 2));
}
// Pre-tail, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// No RunRead or MoveSrcSliceWindow here, already finished them all!
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
block_sync_lds();
local_load_func();
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// Tail, always perform.
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;