Wmma support for gemm_ab_scale (#3314)

* Support gemm_ab_scale:

 - Add tests
 - Integrate scaling implementation in multiple D
 - Generalize existing b_scale for ab_scale
 - Add instances
 - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK (reference semantics sketched after this list)
 - Add support for all layouts supported by xdl
 - Fix splitk xdl
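
For reference, a minimal sketch (not from this PR) of the block-scaled GEMM semantics that the gemm_ab_scale instances target; the names, row-major layouts, and scale storage order (K-block index contiguous) are assumptions for illustration only:

// Illustration only: one scale per ScaleBlockM x ScaleBlockK tile of A and per
// ScaleBlockN x ScaleBlockK tile of B multiplies the corresponding partial products.
#include <cstdint>
#include <vector>

template <typename AType, typename BType, typename ScaleType, typename AccType>
void reference_ab_scale_gemm(const std::vector<AType>& a,           // M x K
                             const std::vector<BType>& b,           // K x N
                             const std::vector<ScaleType>& a_scale, // ceil(M/ScaleBlockM) x ceil(K/ScaleBlockK)
                             const std::vector<ScaleType>& b_scale, // ceil(N/ScaleBlockN) x ceil(K/ScaleBlockK)
                             std::vector<AccType>& e,               // M x N
                             std::int64_t M, std::int64_t N, std::int64_t K,
                             std::int64_t ScaleBlockM, std::int64_t ScaleBlockN, std::int64_t ScaleBlockK)
{
    const std::int64_t scale_k_blocks = (K + ScaleBlockK - 1) / ScaleBlockK;
    for(std::int64_t m = 0; m < M; ++m)
        for(std::int64_t n = 0; n < N; ++n)
        {
            AccType acc{0};
            for(std::int64_t k = 0; k < K; ++k)
            {
                // Scales are looked up per (row block, K block) of A and (column block, K block) of B.
                const AccType sa = static_cast<AccType>(
                    a_scale[(m / ScaleBlockM) * scale_k_blocks + k / ScaleBlockK]);
                const AccType sb = static_cast<AccType>(
                    b_scale[(n / ScaleBlockN) * scale_k_blocks + k / ScaleBlockK]);
                acc += static_cast<AccType>(a[m * K + k]) * static_cast<AccType>(b[k * N + n]) * sa * sb;
            }
            e[m * N + n] = acc;
        }
}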

* Fix copyright

* Wmma support for gemm_blockscale_wp (#3315)

* Support for preshuffle with ab scale

 - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale
 - add support for AScaleLayout and BScaleLayout (can be different
   from ALayout and BLayout, respectively)
 - add Run method in v1 pipeline to support preshuffle + scaling
 - add support for preshuffle gemms in common invoker
 - Add splitk support (host-side usage sketched after this list)
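
A minimal host-side sketch (not from this PR) of how the new split-K interface is driven. DeviceOpT stands for a concrete DeviceGemmMultipleD_ABScaleSplitK instance (e.g. a fully specialized DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3); the zero-D-tensor configuration, PassThrough element-wise operations, and all pointers/sizes/strides are assumptions for illustration:

// Sketch only; requires the device-op, element-wise-operation and stream-config headers from CK.
template <typename DeviceOpT>
float run_ab_scale_gemm_splitk(DeviceOpT& gemm,
                               const void* p_a, const void* p_b, void* p_e,
                               const void* p_a_scale, const void* p_b_scale,
                               ck::index_t M, ck::index_t N, ck::index_t K,
                               ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideE,
                               ck::index_t KBatch)
{
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    // KBatch is the extra parameter the *SplitK interfaces add to MakeArgumentPointer.
    auto arg = gemm.MakeArgumentPointer(p_a, p_b, {}, p_e,
                                        M, N, K,
                                        StrideA, StrideB, {}, StrideE,
                                        p_a_scale, p_b_scale,
                                        PassThrough{}, PassThrough{}, PassThrough{},
                                        KBatch);

    if(!gemm.IsSupportedArgument(arg.get()))
        return -1.f; // e.g. rejected when KBatch > 1 and KRead % ScaleBlockK != 0

    auto invoker = gemm.MakeInvokerPointer();
    return invoker->Run(arg.get(), StreamConfig{nullptr, /*time_kernel=*/true});
}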

* Fix copyright header
Enrico Degregori authored on 2025-12-11 09:06:20 +01:00, committed by GitHub
parent d66e5f667c, commit ce99cab605
51 changed files with 5144 additions and 552 deletions


@@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base
}
};
template <index_t ScaleSliceSizeN,
template <index_t ScaleSliceSizeMN,
index_t ScaleSliceStrideMN,
index_t ScaleSliceSizeK,
index_t NWaves,
index_t ScaleBlockK,
index_t NumberOfBuffers,
index_t RegSizePerWmma,
typename GridDesc,
typename ThreadCopy,
typename GridBuffer,
typename ThreadStaticBuffer,
typename BScaleThreadDesc>
struct BScale
typename ThreadDesc>
struct ABScale
{
__device__ BScale(GridDesc b_scale_grid_desc_,
ThreadCopy b_scale_thread_copy_,
GridBuffer b_scale_grid_buf_)
: b_scale_thread_copy(b_scale_thread_copy_),
b_scale_grid_desc(b_scale_grid_desc_),
b_scale_grid_buf(b_scale_grid_buf_) {};
__device__ ABScale(GridDesc scale_grid_desc_,
ThreadCopy scale_thread_copy_,
GridBuffer scale_grid_buf_)
: scale_thread_copy(scale_thread_copy_),
scale_grid_desc(scale_grid_desc_),
scale_grid_buf(scale_grid_buf_) {};
static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{});
static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
static constexpr index_t num_slice_mn = ScaleSliceSizeMN;
static constexpr index_t num_slice_k = ScaleSliceSizeK;
static constexpr index_t reg_size_per_wmma = RegSizePerWmma;
static constexpr auto b_scale_thread_copy_step =
make_tuple(make_multi_index(NWaves * NPerWmma, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
static constexpr auto scale_thread_desc = ThreadDesc{};
static constexpr auto scale_thread_copy_step =
make_tuple(make_multi_index(ScaleSliceStrideMN, 0),
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0),
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN,
ScaleSliceSizeK));
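// GlobalLoad: copy this block's scales into register buffer NBuffer. Within the loop the source
// window steps along MN; afterwards it either rewinds along MN only, or (when cond is true)
// also advances to the next scale block along K.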
template <index_t NBuffer>
__device__ void GlobalLoad(bool cond)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, Number<0>{}),
b_scale_thread_bufs(Number<NBuffer>{}));
static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) {
scale_thread_copy.Run(scale_grid_desc,
scale_grid_buf,
scale_thread_desc,
make_tuple(m0 * Number<RegSizePerWmma>{}, Number<0>{}),
scale_thread_bufs(Number<NBuffer>{}));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<0>{}));
});
if(cond)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<1>{}));
}
}
ThreadCopy b_scale_thread_copy;
GridDesc b_scale_grid_desc;
GridBuffer b_scale_grid_buf;
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> b_scale_thread_bufs;
ThreadCopy scale_thread_copy;
GridDesc scale_grid_desc;
GridBuffer scale_grid_buf;
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> scale_thread_bufs;
};
template <typename AScaleStruct, typename BScaleStruct>
struct CScale
{
__device__ CScale() {}
static constexpr auto reg_size_per_wmma =
ck::is_same_v<BScaleStruct, Empty> && ck::is_same_v<AScaleStruct, Empty>
? 1
: wmma_gemm.GetRegSizePerWmma();
static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{},
Number<AScaleStruct::num_slice_mn>{},
Number<BScaleStruct::num_slice_mn>{}));
using CScaleThreadDesc = decltype(c_scale_thread_desc);
static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
using ThreadStaticBuffer = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize()));
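// Load: combine the per-block A and B scales into one C-scale product per (k, m, n) scale block;
// this product later multiplies the corresponding partial WMMA accumulators.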
__device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct)
{
using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc);
using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc);
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_bufs(I0)(Number<c_offset>{}) =
a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
});
});
});
}
__device__ void Clear()
{
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
}
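// UpdateCThreadBuf: scale the partial accumulator (c_thread_buf_per_scale) by the combined
// scale of this (k, m, n) scale block and add it into the main C thread buffer.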
template <index_t k_index, index_t m_index, index_t n_index, typename CThreadBuf>
__device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf)
{
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple(
k_index,
(m_index * num_scale_m_block / MRepeat) % num_scale_m_block +
(Number<t / (reg_size_per_wmma / AScaleStruct::reg_size_per_wmma)>{}) %
AScaleStruct::reg_size_per_wmma,
(n_index * num_scale_n_block / NRepeat) % num_scale_n_block));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_bufs(I0)[Number<cscale_offset>{}]);
});
}
StaticallyIndexedArray<ThreadStaticBuffer, Number<1>{}> c_scale_thread_bufs;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, AccDataType, 1, reg_size_per_wmma, true>
c_thread_buf_per_scale;
};
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }


@@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -207,6 +209,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
// Local load
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
@@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
Base::a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
Base::b_thread_desc_.GetElementSpaceSize());
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Initialize C
c_thread_buf.Clear();
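// blockwise_gemm_func: after loading the A/B fragments, accumulate each group of
// KRepeat / NumScaleKBlock WMMA ops into c_thread_buf_per_scale and then apply the combined
// AB scale for that scale block via UpdateCThreadBuf.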
auto blockwise_gemm_func = [&]() {
// Local load
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
Base::a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
Base::b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
};
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
block_sync_lds();
blockwise_gemm_func();
block_sync_lds();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
blockwise_gemm_func();
}
}
protected:
// A[MRepeat, I1, I1, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
return TailNumber::Full;
}
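// KLoopParams: groups the K repeats of a cluster by scale block along K (NumScaleKBlock groups);
// the Empty/Empty specialization covers the unscaled case.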
template <typename AScaleStruct, typename BScaleStruct>
struct KLoopParams
{
static constexpr auto KRepeatNoScale = 1;
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
};
template <>
struct KLoopParams<Empty, Empty>
{
static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
static constexpr index_t NumScaleKBlock = 1;
static constexpr index_t KRepeatPerNumScaleKBlock = 1;
};
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
@@ -543,7 +747,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -576,6 +782,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(I0)[Number<
b_scale_struct.scale_thread_bufs(I0)[Number<
n0 * BScaleStruct::num_scale_k_block +
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
@@ -982,7 +1190,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer&,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct&,
index_t num_loop,
index_t) const
@@ -1319,6 +1529,248 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc&,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer&,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
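// gemm_core_func: for each (m0, n0) tile and each scale block along K, accumulate the WMMA
// partial products into c_thread_buf_per_scale and then fold in the combined AB scale via
// UpdateCThreadBuf; reg_buf selects which prefetched B register buffer to consume.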
auto gemm_core_func = [&](auto reg_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k_index,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
};
auto a_local_prefetch_func = [&]() {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
});
};
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
__builtin_amdgcn_sched_barrier(0);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1
block_sync_lds();
a_local_prefetch_func();
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
a_scale_struct.template GlobalLoad<0>(
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
gemm_core_func(wmma_reg_buf);
block_sync_lds();
// loop prefetch copy
a_local_prefetch_func();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
gemm_core_func(I0);
block_sync_lds();
// tail Local Prefetch A1
a_local_prefetch_func();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
__builtin_amdgcn_sched_barrier(0);
gemm_core_func(I1);
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == TailNumber::Odd)
{
gemm_core_func(I0);
}
}
protected:
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<KPack / B_K1 / B_KRow>{},


@@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
KInner,
TransposeC>;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::I3;
using Base::A_K1;
using Base::A_KRow;
@@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -383,6 +388,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
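// The main body, pre-tail and tail below share the same scaled compute pattern: accumulate
// KRepeat / NumScaleKBlock WMMA ops into c_thread_buf_per_scale, then apply the combined AB
// scale for that scale block via UpdateCThreadBuf.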
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
// Initialize C
c_thread_buf.Clear();
// Local prefetch 1
block_sync_lds();
auto local_load_func = [&]() {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
});
};
local_load_func();
__builtin_amdgcn_sched_barrier(0);
// Main body, perform when at least 3 loops exist.
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
block_sync_lds();
local_load_func();
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 2));
}
// Pre-tail, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// No RunRead or MoveSrcSliceWindow here, already finished them all!
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
block_sync_lds();
local_load_func();
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// Tail, always perform.
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;


@@ -105,6 +105,353 @@ struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator
virtual int GetPreShuffleParameters() = 0;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
};
/// @brief Wrapper for backward compatibility that allows using instances of
///        DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK in contexts where
///        DeviceGemmMultipleD_BlockScale_BPreshuffle is expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
///       The only difference between the APIs of DeviceGemmMultipleD_BlockScale_BPreshuffle and
///       DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK is that
///       DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK::MakeArgumentPointer requires an
///       additional parameter KBatch, which is explicitly passed as 1 by this wrapper.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
p_a_scale,
p_b_scale,
a_element_op,
b_element_op,
cde_element_op,
1); // KBatch
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
int GetPreShuffleParameters() override { return p_op_->GetPreShuffleParameters(); }
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
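// e.g. (illustrative) with PassThrough a_op/b_op, one D tensor and cde_op(c, d0) = c + d0:
// E[m, n] = sum_k A[m, k] * B[k, n] + D0[m, n]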
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_ABScaleSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
};
/// @brief Wrapper for backward compatibility that allows using instances of
///        DeviceGemmMultipleD_ABScaleSplitK in contexts where DeviceGemmMultipleD_ABScale is
///        expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
///       The only difference between the APIs of DeviceGemmMultipleD_ABScale and
///       DeviceGemmMultipleD_ABScaleSplitK is that
///       DeviceGemmMultipleD_ABScaleSplitK::MakeArgumentPointer requires an additional parameter
///       KBatch, which is explicitly passed as 1 by this wrapper.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_ABScaleSplitKWrapper
: public DeviceGemmMultipleD_ABScale<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleD_ABScaleSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
p_a_scale,
p_b_scale,
a_element_op,
b_element_op,
cde_element_op,
1); // KBatch
}
void SetKBatch(BaseArgument* arg, int KBatch) const override { p_op_->SetKBatch(arg, KBatch); }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
@@ -93,7 +93,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
p_bs_grid_shift,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
karg.p_a_scale_grid,
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -315,12 +316,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
Tuple<ADataType>,
void, // AScaleType
Tuple<BDataType>,
BScaleDataType,
AccDataType,
@@ -332,6 +334,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
CElementwiseOperation,
GemmSpec,
BlockSize,
0, // ScaleBlockM
ScaleBlockN,
ScaleBlockK,
MPerBlock,
@@ -405,7 +408,9 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
std::array<index_t, 1>{StrideB_},
std::array<index_t, 0>{}, // StrideDs_
StrideC_,
0, // StrideScaleA
StrideScaleB_,
nullptr,
p_b_scale_grid_,
k_batch_,
a_element_op_,


@@ -0,0 +1,362 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM, // scale block for M
index_t ScaleBlockN, // scale block for N
index_t ScaleBlockK, // scale block for K
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3
: public DeviceGemmMultipleD_ABScaleSplitK<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
DsLayout,
CLayout,
Tuple<ADataType>,
AScaleDataType,
Tuple<BDataType>,
BScaleDataType,
AccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
// with splitk the implementation does not work when KRead % ScaleBlockK != 0
// (each K split must cover a whole number of ScaleBlockK blocks), regardless of K padding
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
{
return false;
}
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
void SetKBatch(BaseArgument* base_arg, int KBatch) const override
{
auto& arg = *dynamic_cast<Argument*>(base_arg);
arg.KBatch = KBatch;
arg.KRead = GridwiseGemm::CalculateKRead(arg.K, KBatch);
arg.KPadded = GridwiseGemm::CalculateKPadded(arg.K, KBatch);
arg.AK0 = GridwiseGemm::CalculateAK0Padded(arg.K, KBatch);
arg.BK0 = GridwiseGemm::CalculateBK0Padded(arg.K, KBatch);
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
std::array<const void*, NumDTensor> p_ds,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
const AScaleDataType* p_a_scale,
const BScaleDataType* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op,
index_t KBatch = 1)
{
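// Default scale strides assume the A scale grid is ceil(M / ScaleBlockM) x ceil(K / ScaleBlockK)
// and the B scale grid is ceil(N / ScaleBlockN) x ceil(K / ScaleBlockK), with the stride being
// the extent of the contiguous dimension implied by the data layout.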
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
p_c,
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
p_a_scale,
p_b_scale,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch = 1) override
{
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemm_ABScale_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma<<"x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat<<"x" << NRepeat<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -18,52 +18,6 @@
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
const index_t k_id = blockIdx.z * num_k_per_block;
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
@@ -202,270 +156,14 @@ struct DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
true>; // IsBPreshuffle
// Invoker
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
std::array<std::size_t, 1> size_as_buffers;
size_as_buffers[Number<0>{}] =
a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
std::array<std::size_t, 1> size_bs_buffers;
size_bs_buffers[Number<0>{}] =
b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
size_ds_buffers[i] =
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiABD<Argument,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType>
rotating_mem(arg_,
stream_config.rotating_count,
size_as_buffers,
size_bs_buffers,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
0,
arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
0,
arg.M * arg.N * sizeof(EDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
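// Illustration (hypothetical tile sizes): with MPerBlock = NPerBlock = 128 and BlockSize = 256,
// MPerBlock * NPerBlock / BlockSize = 64 <= 128, so the Intrawave v3 pipeline requests a minimum
// occupancy of 2; larger tiles fall back to 1, and Interwave always uses 2.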
// ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
// currently implemented in such a way that all SrcScalarPerVectors must be the same, so
// if one of the D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
// other hand, split-K for 16-bit outputs uses packed atomics, so the ScalarPerVectors cannot
// be odd.
constexpr bool AtomicsImplementationExists =
!(std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t> ||
std::is_same_v<EDataType, int8_t>) ||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
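// Illustration (assumed types): for EDataType = ck::half_t, split-K relies on packed f16 atomics,
// so the first CDEShuffleBlockTransferScalarPerVectors entry must be even; for a 32-bit output
// such as float, no packing is needed and AtomicsImplementationExists is true regardless.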
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
return DeviceGemmCommon::IsSupportedArgument(arg);
}


@@ -0,0 +1,360 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM, // scale block for M
index_t ScaleBlockN, // scale block for N
index_t ScaleBlockK, // scale block for K
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle
: public DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using AScaleLayout = tensor_layout::gemm::ColumnMajor;
using BScaleLayout = BLayout;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
DsLayout,
CLayout,
Tuple<ADataType>,
AScaleDataType,
Tuple<BDataType>,
BScaleDataType,
AccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB,
true,
AScaleLayout,
BScaleLayout>;
using Argument = typename GridwiseGemm::Argument;
int GetPreShuffleParameters() override { return NPerWmma; }
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
true>; // IsBPreshuffle
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
// With split-K the implementation does not work
// when KRead % ScaleBlockK != 0, regardless of K padding.
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
{
return false;
}
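// Illustration (hypothetical sizes): with ScaleBlockK = 128, a per-split KRead of 256 passes
// (256 % 128 == 0), while KRead = 192 is rejected whenever KBatch > 1.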
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op,
index_t KBatch)
{
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
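// Illustration (hypothetical sizes): with K = 4096 and ScaleBlockK = 128, the K-major cases above
// give a scale stride of ceil(4096 / 128) = 32; otherwise the stride is taken along M
// (ceil(M / ScaleBlockM)) or N (ceil(N / ScaleBlockN)).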
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<CDataType*>(p_e),
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch) override
{
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<CDataType*>(p_e),
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma<<"x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat<<"x" << NRepeat<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck


@@ -262,6 +262,16 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -279,22 +289,47 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
}
}
}
@@ -315,6 +350,20 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
{
auto& arg = *dynamic_cast<Argument*>(base_arg);
arg.KBatch = KBatch;
if(get_warp_size() == 64)
{
arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch);
arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch);
arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch);
arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch);
}
else
{
arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch);
arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch);
arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch);
arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch);
}
}
static constexpr bool IsValidCompilationParameter()
@@ -325,6 +374,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
static bool IsSupportedArgument(const Argument& arg)
{
// With split-K the implementation does not work
// when KRead % ScaleBlockK != 0, regardless of K padding.
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
{
return false;
}
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;
@@ -385,6 +441,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
@@ -396,6 +460,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
StrideB,
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,
@@ -425,6 +491,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
@@ -436,6 +510,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
StrideB,
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,


@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
@@ -86,12 +86,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
Tuple<ADataType>,
void, // AScaleType
Tuple<BDataType>,
BScaleDataType,
AccDataType,
@@ -103,6 +104,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
CElementwiseOperation,
GemmSpec,
BlockSize,
0, // ScaleBlockM
ScaleBlockN,
ScaleBlockK,
MPerBlock,
@@ -207,7 +209,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
0, // StrideScaleA
StrideScaleB,
nullptr,
p_b_scale,
KBatch,
a_element_op,
@@ -245,7 +249,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
0, // StrideScaleA
StrideScaleB,
nullptr, // p_a_scale
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,


@@ -38,7 +38,8 @@ template <typename GridwiseGemm,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB>
typename ComputeTypeB,
bool IsBPreShuffled = false>
struct DeviceGemm_Wmma_CShuffleV3_Common
{
@@ -189,61 +190,174 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
if constexpr(IsBPreShuffled)
{
if(arg.KBatch > 1)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if constexpr(AtomicsImplementationExists)
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
{
// TODO: Implement
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if constexpr(AtomicsImplementationExists)
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
}
}
else
{
if constexpr(IsBPreShuffled)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
}
@@ -299,6 +413,14 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
return false;
}
if constexpr(IsBPreShuffled)
{
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
}
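// Illustration (hypothetical tile sizes): with NPerBlock = 128 and KPerBlock = 64 the preshuffled
// path accepts N = 4096, K = 8192 but rejects N = 4100, since it assumes whole N and K tiles.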
return GridwiseGemm::CheckValidity(arg);
}
};


@@ -388,11 +388,11 @@ struct ABTransferThreadTiles
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(make_unmerge_transform(make_tuple(
Number<ABK0 / KRow>{}, KRow, Number<KPack / KRow / ABK1>{})),
make_unmerge_transform(make_tuple(
Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
}


@@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
c_thread_buf.Clear();
// Empty BScale struct for the blockwise pipeline.
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using ABScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = ABScale{};
auto b_scale_struct = ABScale{};
/*******************************************************************************/
//
@@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
b0_block_buf,
b0_block_slice_copy_step,
acc0_thread_buf,
a_scale_struct,
b_scale_struct,
KBlockMainLoop,
1); // num_k_block_per_scale


@@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using Scale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = Scale{};
auto b_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
@@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
@@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_id);


@@ -23,6 +23,7 @@ template <typename ALayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename AScaleType,
typename BsDataType,
typename BScaleType,
typename AccDataType,
@@ -34,6 +35,7 @@ template <typename ALayout,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN, // scale N
index_t ScaleBlockK, // scale K
index_t MPerBlock,
@@ -65,13 +67,16 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
typename AScaleLayout = ALayout,
typename BScaleLayout = BLayout>
struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
@@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
@@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>;
using Base::I0;
@@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
@@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
@@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
NBlock{CalculateNBlock(N_)},
Kt{K_}
{
}
@@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
<< std::endl;
std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", "
<< "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
@@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
@@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t BK0;
index_t MBlock;
index_t NBlock;
index_t Kt;
};
// Argument
@@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
@@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs_,
StrideDs_,
StrideE_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_a_scale_grid{p_a_scale_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
@@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AScaleType* p_a_scale_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
@@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
if constexpr(IsBPreShuffled)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else
{
if constexpr(!PermuteB)
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
static_for<0, NumBTensor, 1>{}([&](auto i) {
b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
});
}
else
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
if constexpr(!PermuteB)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
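// Illustration (hypothetical sizes): k_id = 2, KRead = 256 and ScaleBlockK = 128 skip
// 2 * (256 / 128) = 4 scale blocks along K, multiplied by StrideScale{A,B} when K is the strided
// dimension of the corresponding scale layout.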
if(k_id < karg.KBatch - 1)
@@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t scale_k_split_offset; // New member for scale matrix offset
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
index_t c_reduce_offset;
};
using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const BScaleType* p_b_scale_grid,
index_t block_n_id)
__device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}
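// Illustration (hypothetical sizes): M = 1024, K = 512, ScaleBlockM = 128 and ScaleBlockK = 64
// give an 8 x 8 scale descriptor; row-major A scales use strides (StrideScaleA, 1), column-major
// ones (1, StrideScaleA).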
static constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma;
template <index_t NumberOfBuffers>
__device__ static auto
MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id)
{
if constexpr(ck::is_same_v<AScaleType, void>)
{
using AScale = typename BlockwiseGemmPipe::Empty;
return AScale{};
}
else
{
#if defined(__gfx11__)
// TODO: remove this restriction
static_assert(ScaleBlockM >= MPerWmma,
"ScaleBlockM must be greater equal than MPerWmma");
#endif
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");
static constexpr auto ScaleSliceSizeN = NRepeat;
static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
constexpr auto RegSizePerWmmaFull =
wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number;
constexpr auto RegSizePerWmma =
math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM);
auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma;
auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread;
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
b_thread_offset_k / ScaleBlockK));
constexpr auto ScaleSliceSizeM =
ScaleBlockM < MPerWmma ? MRepeat * RegSizePerWmma
: math::integer_divide_ceil(MPerBlock, ScaleBlockM);
constexpr auto ScaleSliceStrideM =
math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_scale_thread_desc.GetElementSpaceSize());
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
using BScale =
typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
ScaleSliceSizeK,
NWaves,
ScaleBlockK,
NumberOfBuffers,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
auto a_thread_offset_m =
((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) /
math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) +
(get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM;
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
constexpr index_t VectorDim =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? RegSizePerWmma
: ScaleSliceSizeK;
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleType,
AScaleType,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<RegSizePerWmma, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0));
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleType>(
a_scale_thread_desc.GetElementSpaceSize());
using AScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeM,
ScaleSliceStrideM,
ScaleSliceSizeK,
NumberOfBuffers,
RegSizePerWmma,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_copy),
decltype(a_scale_grid_buf),
decltype(a_scale_thread_buf),
decltype(a_scale_thread_desc)>;
return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf};
}
}
__device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}
template <index_t NumberOfBuffers>
__device__ static auto
MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id)
{
if constexpr(ck::is_same_v<BScaleType, void>)
{
using BScale = typename BlockwiseGemmPipe::Empty;
return BScale{};
}
else
{
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");
const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
constexpr auto ScaleSliceSizeN =
ScaleBlockN < NPerWmma ? NRepeat
: math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr auto ScaleSliceStrideN =
math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma) /
ScaleBlockN;
constexpr index_t VectorDim =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 1 : ScaleSliceSizeK;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0));
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleType>(
b_scale_thread_desc.GetElementSpaceSize());
using BScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeN,
ScaleSliceStrideN,
ScaleSliceSizeK,
NumberOfBuffers,
1,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
}
}
__device__ static index_t GetKBlockPerScale()
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
if constexpr(ck::is_same_v<AScaleType, void> && ck::is_same_v<BScaleType, void>)
{
return 0;
}
else
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
}
}
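// Illustration (hypothetical sizes): ScaleBlockK = 256 and KPerBlock = 64 give 4, i.e. four
// main-loop K blocks share one scale value along K; when both scale types are void the function
// returns 0.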
template <bool HasMainKBlockLoop,
@@ -539,18 +720,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const AScaleType* p_a_scale_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct
auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id);
// BScale struct
auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id);
const index_t num_k_block_per_scale = GetKBlockPerScale();
@@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
@@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
epilogue_args,
k_id);
}
// NOTE: Wrapper function to have __global__ function in common
@@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
splitk_batch_offset.b_k_split_offset[i];
});
const AScaleType* p_a_scale_grid_ptr;
if constexpr(ck::is_same_v<AScaleType, void>)
{
p_a_scale_grid_ptr = karg.p_a_scale_grid;
}
else
{
p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset;
}
const BScaleType* p_b_scale_grid_ptr;
if constexpr(ck::is_same_v<BScaleType, void>)
{
p_b_scale_grid_ptr = karg.p_b_scale_grid;
}
else
{
p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset;
}
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_a_scale_grid_ptr,
p_b_scale_grid_ptr,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
epilogue_args,
k_id);
}
};


@@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
}
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
const index_t k_id = blockIdx.z * num_k_per_block;
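// Illustration (hypothetical sizes): with karg.K = 4096 and KPack = 16, num_k_per_block = 256,
// so the block at blockIdx.z = 3 computes k_id = 768.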
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
#endif
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
@@ -162,7 +204,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
static constexpr index_t KInner = IsBPreShuffled ? KInnerB : ck::math::min(KInnerA, KInnerB);
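// Illustration (hypothetical values): with KInnerA = 1 and KInnerB = 2, the non-preshuffled path
// uses min(1, 2) = 1, while a preshuffled B takes KInnerB = 2.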
static constexpr index_t KPack =
KInner *
@@ -966,6 +1008,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AScaleStruct,
typename BScaleStruct,
typename EpilogueArgument,
bool HasMainKBlockLoop,
@@ -988,6 +1031,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
@@ -1072,6 +1116,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_scale_struct,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);


@@ -43,13 +43,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
}
}
__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
__host__ __device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}
__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
__host__ __device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}
@@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
N{N_},
@@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideC{StrideC_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
@@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideC;
index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
index_t NPadded;
@@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideDs_,
StrideC_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{},
@@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
b_k_split_offset = blockIdx.z * karg.KRead;
}
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
scale_a_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_b_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
@@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(