Merge branch 'develop' into ck_tile/gemm_blockscale_abquant

This commit is contained in:
kensclin
2025-12-12 01:06:46 +08:00
committed by GitHub
179 changed files with 9598 additions and 2091 deletions

View File

@@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base
}
};
template <index_t ScaleSliceSizeN,
template <index_t ScaleSliceSizeMN,
index_t ScaleSliceStrideMN,
index_t ScaleSliceSizeK,
index_t NWaves,
index_t ScaleBlockK,
index_t NumberOfBuffers,
index_t RegSizePerWmma,
typename GridDesc,
typename ThreadCopy,
typename GridBuffer,
typename ThreadStaticBuffer,
typename BScaleThreadDesc>
struct BScale
typename ThreadDesc>
struct ABScale
{
__device__ BScale(GridDesc b_scale_grid_desc_,
ThreadCopy b_scale_thread_copy_,
GridBuffer b_scale_grid_buf_)
: b_scale_thread_copy(b_scale_thread_copy_),
b_scale_grid_desc(b_scale_grid_desc_),
b_scale_grid_buf(b_scale_grid_buf_) {};
__device__ ABScale(GridDesc scale_grid_desc_,
ThreadCopy scale_thread_copy_,
GridBuffer scale_grid_buf_)
: scale_thread_copy(scale_thread_copy_),
scale_grid_desc(scale_grid_desc_),
scale_grid_buf(scale_grid_buf_) {};
static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{});
static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
static constexpr index_t num_slice_mn = ScaleSliceSizeMN;
static constexpr index_t num_slice_k = ScaleSliceSizeK;
static constexpr index_t reg_size_per_wmma = RegSizePerWmma;
static constexpr auto b_scale_thread_copy_step =
make_tuple(make_multi_index(NWaves * NPerWmma, 0),
make_multi_index(-NPerBlock, 0),
make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
static constexpr auto scale_thread_desc = ThreadDesc{};
static constexpr auto scale_thread_copy_step =
make_tuple(make_multi_index(ScaleSliceStrideMN, 0),
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0),
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN,
ScaleSliceSizeK));
template <index_t NBuffer>
__device__ void GlobalLoad(bool cond)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(n0, Number<0>{}),
b_scale_thread_bufs(Number<NBuffer>{}));
static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) {
scale_thread_copy.Run(scale_grid_desc,
scale_grid_buf,
scale_thread_desc,
make_tuple(m0 * Number<RegSizePerWmma>{}, Number<0>{}),
scale_thread_bufs(Number<NBuffer>{}));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<0>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<0>{}));
});
if(cond)
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<2>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<2>{}));
}
else
{
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
b_scale_thread_copy_step.At(Number<1>{}));
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
scale_thread_copy_step.At(Number<1>{}));
}
}
ThreadCopy b_scale_thread_copy;
GridDesc b_scale_grid_desc;
GridBuffer b_scale_grid_buf;
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> b_scale_thread_bufs;
ThreadCopy scale_thread_copy;
GridDesc scale_grid_desc;
GridBuffer scale_grid_buf;
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> scale_thread_bufs;
};
template <typename AScaleStruct, typename BScaleStruct>
struct CScale
{
__device__ CScale() {}
static constexpr auto reg_size_per_wmma =
ck::is_same_v<BScaleStruct, Empty> && ck::is_same_v<AScaleStruct, Empty>
? 1
: wmma_gemm.GetRegSizePerWmma();
static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{},
Number<AScaleStruct::num_slice_mn>{},
Number<BScaleStruct::num_slice_mn>{}));
using CScaleThreadDesc = decltype(c_scale_thread_desc);
static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
using ThreadStaticBuffer = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize()));
__device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct)
{
using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc);
using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc);
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
constexpr index_t c_offset =
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
constexpr index_t a_offset =
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
constexpr index_t b_offset =
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
c_scale_thread_bufs(I0)(Number<c_offset>{}) =
a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
});
});
});
}
__device__ void Clear()
{
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
}
template <index_t k_index, index_t m_index, index_t n_index, typename CThreadBuf>
__device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf)
{
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t));
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple(
k_index,
(m_index * num_scale_m_block / MRepeat) % num_scale_m_block +
(Number<t / (reg_size_per_wmma / AScaleStruct::reg_size_per_wmma)>{}) %
AScaleStruct::reg_size_per_wmma,
(n_index * num_scale_n_block / NRepeat) % num_scale_n_block));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()[Number<t>{}] *
type_convert<AccDataType>(c_scale_thread_bufs(I0)[Number<cscale_offset>{}]);
});
}
StaticallyIndexedArray<ThreadStaticBuffer, Number<1>{}> c_scale_thread_bufs;
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, AccDataType, 1, reg_size_per_wmma, true>
c_thread_buf_per_scale;
};
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

View File

@@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -207,6 +209,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
// Local load
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
@@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
Base::a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
Base::b_thread_desc_.GetElementSpaceSize());
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Initialize C
c_thread_buf.Clear();
auto blockwise_gemm_func = [&]() {
// Local load
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
Base::a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
Base::b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
};
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
block_sync_lds();
blockwise_gemm_func();
block_sync_lds();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
i += 1;
} while(i < (num_loop - 1));
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
block_sync_lds();
blockwise_gemm_func();
}
}
protected:
// A[MRepeat, I1, I1, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
@@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
return TailNumber::Full;
}
template <typename AScaleStruct, typename BScaleStruct>
struct KLoopParams
{
static constexpr auto KRepeatNoScale = 1;
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
};
template <>
struct KLoopParams<Empty, Empty>
{
static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
static constexpr index_t NumScaleKBlock = 1;
static constexpr index_t KRepeatPerNumScaleKBlock = 1;
};
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
@@ -543,7 +747,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -576,6 +782,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(I0)[Number<
b_scale_struct.scale_thread_bufs(I0)[Number<
n0 * BScaleStruct::num_scale_k_block +
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
@@ -982,7 +1190,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer&,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct&,
index_t num_loop,
index_t) const
@@ -1319,6 +1529,248 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc&,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer&,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
auto gemm_core_func = [&](auto reg_buf) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_bufs[reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
I0,
I0,
n0,
I0,
k_index,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
};
auto a_local_prefetch_func = [&]() {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, KRepeat, 1>{}([&](auto k0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
});
};
// Global prefetch A1 B1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(I0));
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
__builtin_amdgcn_sched_barrier(0);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill A1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
// Global prefetch A2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
// Local prefetch A1
block_sync_lds();
a_local_prefetch_func();
// Initialize C
c_thread_buf.Clear();
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(local_read_buf));
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
a_scale_struct.template GlobalLoad<0>(
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
gemm_core_func(wmma_reg_buf);
block_sync_lds();
// loop prefetch copy
a_local_prefetch_func();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
};
LoopFunc(I0, I1);
LoopFunc(I1, I0);
i += 2;
} while(i < (num_loop - 2));
}
// tail
if constexpr(TailNum == TailNumber::Even)
{
b_blockwise_copy.Run(b_grid_desc,
b_grid_buf,
b_block_desc_k0_n0_n1_n2_k1,
b_block_origin_idx,
b_thread_bufs(I1));
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
gemm_core_func(I0);
block_sync_lds();
// tail Local Prefetch A1
a_local_prefetch_func();
c_scale_struct.Load(a_scale_struct, b_scale_struct);
__builtin_amdgcn_sched_barrier(0);
gemm_core_func(I1);
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(TailNum == TailNumber::Odd)
{
gemm_core_func(I0);
}
}
protected:
static constexpr auto b_thread_desc_ =
make_naive_tensor_descriptor_packed(make_tuple(Number<KPack / B_K1 / B_KRow>{},

View File

@@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
KInner,
TransposeC>;
using Base::I0;
using Base::I1;
using Base::I2;
using Base::I3;
using Base::A_K1;
using Base::A_KRow;
@@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
b_scale_struct.scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
@@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename BScaleStruct>
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
@@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
// BScaleThreadCopy
AScaleStruct&,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
@@ -383,6 +388,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
// Local prefill 1
@@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
}
}
template <bool HasMainLoop,
TailNumber TailNum,
typename AGridDesc,
typename ABlockDesc,
typename ABlockTransfer,
typename AGridBuffer,
typename ABlockBuffer,
typename ABlockTransferStep,
typename BGridDesc,
typename BBlockDesc,
typename BBlockTransfer,
typename BGridBuffer,
typename BBlockBuffer,
typename BBlockTransferStep,
typename CThreadBuffer,
typename AScaleStruct,
typename BScaleStruct,
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
!ck::is_same_v<BScaleStruct, Empty>,
bool>::type = false>
__device__ void Run(const AGridDesc& a_grid_desc,
const ABlockDesc& a_block_desc,
ABlockTransfer& a_blockwise_copy,
const AGridBuffer& a_grid_buf,
ABlockBuffer& a_block_buf,
const ABlockTransferStep& a_block_copy_step,
const BGridDesc& b_grid_desc,
const BBlockDesc& b_block_desc,
BBlockTransfer& b_blockwise_copy,
const BGridBuffer& b_grid_buf,
BBlockBuffer& b_block_buf,
const BBlockTransferStep& b_block_copy_step,
CThreadBuffer& c_thread_buf,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
index_t num_loop,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
static constexpr auto NumScaleKBlock =
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_thread_desc_.GetElementSpaceSize());
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
auto c_scale_struct = CScaleStruct{};
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
// Scales global load
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
c_scale_struct.Load(a_scale_struct, b_scale_struct);
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// Global prefetch 2, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
}
// Initialize C
c_thread_buf.Clear();
// Local prefetch 1
block_sync_lds();
auto local_load_func = [&]() {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
});
};
local_load_func();
__builtin_amdgcn_sched_barrier(0);
// Main body, perform when at least 3 loops exist.
if constexpr(HasMainLoop)
{
index_t i = 0;
do
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale
.GetVectorTypeReference(Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
block_sync_lds();
local_load_func();
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
} while(i < (num_loop - 2));
}
// Pre-tail, perform when at least 2 loops exist.
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
{
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
// No RunRead or MoveSrcSliceWindow here, already finished them all!
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
c_scale_struct.Load(a_scale_struct, b_scale_struct);
block_sync_lds();
local_load_func();
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
// Tail, always perform.
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
c_scale_struct.Clear();
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k_index,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
constexpr index_t k_index =
kscale0 * (KRepeat / NumScaleKBlock) + k0;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k_index,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
Number<0>{}));
});
});
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
});
});
});
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
// latency
// __builtin_amdgcn_sched_barrier(0);
}
}
protected:
using Base::a_thread_copy_;
using Base::a_thread_desc_;

View File

@@ -105,6 +105,353 @@ struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator
virtual int GetPreShuffleParameters() = 0;
};
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual int GetPreShuffleParameters() = 0;
};
/// @brief Wrapper for backward compatibility that allows to use instances of
/// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK in contexts where
// DeviceGemmMultipleD_BlockScale_BPreshuffle is expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
/// The only difference between API of DeviceGemmMultipleD_BlockScale_BPreshuffle and
// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK is
/// that DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK::MakeArgumentPointer requires
// an additional parameter KBatch which is explicitly passed as 1 by this wrapper.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper
: public DeviceGemmMultipleD_BlockScale_BPreshuffle<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
p_a_scale,
p_b_scale,
a_element_op,
b_element_op,
cde_element_op,
1); // KBatch
}
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
int GetPreShuffleParameters() override { return p_op_->GetPreShuffleParameters(); }
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
// GEMM:
// input : A[M, K], B[K, N],
// input : D0[M, N], D1[M, N], ...
// output : E[M, N]
// C = a_op(A) * b_op(B)
// E = cde_op(C, D0, D1, ...)
// Assume:
// D0, D1, ... and E have the same layout
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_ABScaleSplitK : public BaseOperator
{
static constexpr index_t NumDTensor = DsDataType::Size();
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
index_t KBatch) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
};
/// @brief Wrapper for backward compatibility that allows to use instances of
/// DeviceGemmMultipleD_ABScaleSplitK in contexts where DeviceGemmMultipleD_ABScale is
/// expected.
///
/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances().
/// The only difference between API of DeviceGemmMultipleD_ABScale and
/// DeviceGemmMultipleD_ABScaleSplitK is that
/// DeviceGemmMultipleD_ABScaleSplitK::MakeArgumentPointer requires a additional parameter
/// KBatch which is explicitly passed as 1 by this wrapper.
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename ELayout,
typename ADataType,
typename AScaleType,
typename BDataType,
typename BScaleType,
typename DsDataType,
typename EDataType,
index_t ScaleBlockM,
index_t ScaleBlockN,
index_t ScaleBlockK,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CDEElementwiseOperation>
struct DeviceGemmMultipleD_ABScaleSplitKWrapper
: public DeviceGemmMultipleD_ABScale<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>
{
using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK<ALayout,
BLayout,
DsLayout,
ELayout,
ADataType,
AScaleType,
BDataType,
BScaleType,
DsDataType,
EDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation>;
static constexpr index_t NumDTensor = DsDataType::Size();
#ifndef __HIPCC_RTC__
explicit DeviceGemmMultipleD_ABScaleSplitKWrapper(std::unique_ptr<DeviceOp> p_op)
: p_op_(std::move(p_op))
{
}
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return p_op_->IsSupportedArgument(p_arg);
}
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
const ck::index_t M,
const ck::index_t N,
const ck::index_t K,
const ck::index_t StrideA,
const ck::index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
const ck::index_t StrideE,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op) override
{
return p_op_->MakeArgumentPointer(p_a,
p_b,
p_ds,
p_e,
M,
N,
K,
StrideA,
StrideB,
StrideDs,
StrideE,
p_a_scale,
p_b_scale,
a_element_op,
b_element_op,
cde_element_op,
1); // KBatch
}
void SetKBatch(BaseArgument* arg, int KBatch) const override { p_op_->SetKBatch(arg, KBatch); }
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return p_op_->MakeInvokerPointer();
}
std::string GetTypeString() const override { return p_op_->GetTypeString(); }
private:
std::unique_ptr<DeviceOp> p_op_;
#endif // __HIPCC_RTC__
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
@@ -93,7 +93,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
p_bs_grid_shift,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset,
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset,
karg.p_a_scale_grid,
karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -315,12 +316,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
};
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
Tuple<ADataType>,
void, // AScaleType
Tuple<BDataType>,
BScaleDataType,
AccDataType,
@@ -332,6 +334,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
CElementwiseOperation,
GemmSpec,
BlockSize,
0, // ScaleBlockM
ScaleBlockN,
ScaleBlockK,
MPerBlock,
@@ -405,7 +408,9 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale
std::array<index_t, 1>{StrideB_},
std::array<index_t, 0>{}, // StrideDs_
StrideC_,
0, // StrideScaleA
StrideScaleB_,
nullptr,
p_b_scale_grid_,
k_batch_,
a_element_op_,

View File

@@ -0,0 +1,362 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM, // scale block for M
index_t ScaleBlockN, // scale block for N
index_t ScaleBlockK, // scale block for K
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3
: public DeviceGemmMultipleD_ABScaleSplitK<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
DsLayout,
CLayout,
Tuple<ADataType>,
AScaleDataType,
Tuple<BDataType>,
BScaleDataType,
AccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB>;
using Argument = typename GridwiseGemm::Argument;
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
// with splitk the implementation doesn't work
// when KRead % ScaleBlockK != 0, independently of K padding
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
{
return false;
}
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
void SetKBatch(BaseArgument* base_arg, int KBatch) const override
{
auto& arg = *dynamic_cast<Argument*>(base_arg);
arg.KBatch = KBatch;
arg.KRead = GridwiseGemm::CalculateKRead(arg.K, KBatch);
arg.KPadded = GridwiseGemm::CalculateKPadded(arg.K, KBatch);
arg.AK0 = GridwiseGemm::CalculateAK0Padded(arg.K, KBatch);
arg.BK0 = GridwiseGemm::CalculateBK0Padded(arg.K, KBatch);
}
static auto MakeArgument(const ADataType* p_a,
const BDataType* p_b,
std::array<const void*, NumDTensor> p_ds,
CDataType* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
const BScaleDataType* p_a_scale,
const BScaleDataType* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op,
index_t KBatch = 1)
{
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
p_c,
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
p_a_scale,
p_b_scale,
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_c,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch = 1) override
{
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<CDataType*>(p_c),
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemm_ABScale_Wmma_CShuffleV3"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma<<"x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat<<"x" << NRepeat<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -18,52 +18,6 @@
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
const index_t k_id = blockIdx.z * num_k_per_block;
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
#endif
}
} // namespace ck
namespace ck {
namespace tensor_operation {
namespace device {
@@ -202,270 +156,14 @@ struct DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB>;
ComputeTypeB,
true>; // IsBPreshuffle
// Invoker
struct Invoker : public BaseInvoker
{
/// @brief This function issues GPU kernel execution.
/// @param arg The GPU kernel arguments.
/// @param stream_config The HIP stream configuration helper structure.
/// @return The kernel's average execution time (if time measurement is
/// enabled).
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
if(stream_config.log_level_ > 0)
{
arg.Print();
GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print();
}
if(!GridwiseGemm::CheckValidity(arg))
{
throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
}
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch);
float ave_time = 0;
index_t k_grain = arg.KBatch * KPerBlock;
index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock;
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split);
const auto Run = [&](const auto& kernel) {
if(stream_config.flush_cache)
{
Argument arg_ = arg;
const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1(
arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0);
const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1(
arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0);
std::array<std::size_t, 1> size_as_buffers;
size_as_buffers[Number<0>{}] =
a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() *
sizeof(ADataType) / GridwiseGemm::APackedSize;
std::array<std::size_t, 1> size_bs_buffers;
size_bs_buffers[Number<0>{}] =
b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() *
sizeof(BDataType) / GridwiseGemm::BPackedSize;
const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N(
arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs);
std::array<std::size_t, GridwiseGemm::NumDTensor> size_ds_buffers;
static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) {
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
size_ds_buffers[i] =
ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType);
});
ck::utility::RotatingMemWrapperMultiABD<Argument,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType>
rotating_mem(arg_,
stream_config.rotating_count,
size_as_buffers,
size_bs_buffers,
size_ds_buffers);
rotating_mem.Print();
auto run_flush_cache = [&]() {
// flush icache
ck::utility::flush_icache();
// rotating mem
rotating_mem.Next();
// clear c mem
if(arg_.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid,
0,
arg_.M * arg_.N * sizeof(EDataType),
stream_config.stream_id_));
};
ave_time = ck::utility::launch_and_time_kernel_with_preprocess<false>(
stream_config,
run_flush_cache,
kernel,
dim3(gdx, gdy, gdz),
dim3(BlockSize),
0,
arg_);
}
else
{
if(arg.KBatch > 1)
HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid,
0,
arg.M * arg.N * sizeof(EDataType),
stream_config.stream_id_));
ave_time = launch_and_time_kernel(
stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg);
}
};
constexpr index_t minimum_occupancy = []() {
if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave)
{
return 2;
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1;
}
else
{
return 1;
}
}();
// ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is
// currently implemented in such a way that all SrcScalarPerVectors must be the same, so
// if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the
// other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot
// be odd.
constexpr bool AtomicsImplementationExists =
!(std::is_same_v<EDataType, ck::half_t> || std::is_same_v<EDataType, ck::bhalf_t> ||
std::is_same_v<EDataType, int8_t>) ||
(CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0);
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
return ave_time;
}
// polymorphic
float Run(const BaseArgument* p_arg,
const StreamConfig& stream_config = StreamConfig{}) override
{
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
}
};
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
return DeviceGemmCommon::IsSupportedArgument(arg);
}

View File

@@ -0,0 +1,360 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <iostream>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
template <typename ALayout,
typename BLayout,
typename DsLayout,
typename CLayout,
typename ADataType,
typename AScaleDataType,
typename BDataType,
typename BScaleDataType,
typename DsDataType,
typename CDataType,
typename AccDataType,
typename CShuffleDataType,
typename AElementwiseOperation,
typename BElementwiseOperation,
typename CElementwiseOperation,
GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM, // scale block for M
index_t ScaleBlockN, // scale block for N
index_t ScaleBlockK, // scale block for K
index_t MPerBlock,
index_t NPerBlock,
index_t KPerBlock,
index_t AK1,
index_t BK1,
index_t MPerWmma,
index_t NPerWmma,
index_t MRepeat,
index_t NRepeat,
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
typename ABlockTransferThreadClusterArrangeOrder,
typename ABlockTransferSrcAccessOrder,
index_t ABlockTransferSrcVectorDim,
index_t ABlockTransferSrcScalarPerVector,
index_t ABlockTransferDstScalarPerVector_AK1,
bool ABlockLdsExtraM,
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
typename BBlockTransferThreadClusterArrangeOrder,
typename BBlockTransferSrcAccessOrder,
index_t BBlockTransferSrcVectorDim,
index_t BBlockTransferSrcScalarPerVector,
index_t BBlockTransferDstScalarPerVector_BK1,
bool BBlockLdsExtraN,
index_t CShuffleMRepeatPerShuffle,
index_t CShuffleNRepeatPerShuffle,
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v1,
typename ComputeTypeA = CDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle
: public DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<ALayout,
BLayout,
DsLayout,
CLayout,
ADataType,
AScaleDataType,
BDataType,
BScaleDataType,
DsDataType,
CDataType,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation>
{
static constexpr index_t NumDTensor = DsDataType::Size();
using AScaleLayout = tensor_layout::gemm::ColumnMajor;
using BScaleLayout = BLayout;
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
DsLayout,
CLayout,
Tuple<ADataType>,
AScaleDataType,
Tuple<BDataType>,
BScaleDataType,
AccDataType,
CShuffleDataType,
DsDataType,
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
GemmSpec,
BlockSize,
ScaleBlockM,
ScaleBlockN,
ScaleBlockK,
MPerBlock,
NPerBlock,
KPerBlock,
AK1,
BK1,
MPerWmma,
NPerWmma,
MRepeat,
NRepeat,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
ABlockTransferThreadClusterArrangeOrder,
ABlockTransferSrcAccessOrder,
ABlockTransferSrcVectorDim,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_AK1,
false,
ABlockLdsExtraM,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
BBlockTransferThreadClusterArrangeOrder,
BBlockTransferSrcAccessOrder,
BBlockTransferSrcVectorDim,
BBlockTransferSrcScalarPerVector,
BBlockTransferDstScalarPerVector_BK1,
false,
BBlockLdsExtraN,
CShuffleMRepeatPerShuffle,
CShuffleNRepeatPerShuffle,
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
PermuteA,
PermuteB,
true,
AScaleLayout,
BScaleLayout>;
using Argument = typename GridwiseGemm::Argument;
int GetPreShuffleParameters() override { return NPerWmma; }
using DeviceGemmCommon =
DeviceGemm_Wmma_CShuffleV3_Common<GridwiseGemm,
Tuple<ADataType>,
Tuple<BDataType>,
DsDataType,
CDataType,
MPerBlock,
NPerBlock,
KPerBlock,
BlockSize,
AK1,
BK1,
GemmSpec,
CShuffleBlockTransferScalarPerVectors,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ComputeTypeA,
ComputeTypeB,
true>; // IsBPreshuffle
// Invoker
using Invoker = typename DeviceGemmCommon::Invoker;
static bool IsSupportedArgument(const Argument& arg)
{
// with splitk the implementation doesn't work
// when KRead % ScaleBlockK != 0, independently of K padding
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
{
return false;
}
return DeviceGemmCommon::IsSupportedArgument(arg);
}
// polymorphic
bool IsSupportedArgument(const BaseArgument* p_arg) override
{
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
}
static auto MakeArgument(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation cde_element_op,
index_t KBatch)
{
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return Argument{std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<CDataType*>(p_e),
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
cde_element_op};
}
static auto MakeInvoker() { return Invoker{}; }
// polymorphic
std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_a,
const void* p_b,
std::array<const void*, NumDTensor> p_ds,
void* p_e,
index_t M,
index_t N,
index_t K,
index_t StrideA,
index_t StrideB,
const std::array<ck::index_t, NumDTensor> StrideDs,
index_t StrideC,
const void* p_a_scale,
const void* p_b_scale,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op,
index_t KBatch) override
{
index_t StrideScaleA = ck::is_same_v<AScaleLayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BScaleLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return std::make_unique<Argument>(std::array<const void*, 1>{p_a},
std::array<const void*, 1>{p_b},
p_ds,
static_cast<CDataType*>(p_e),
M,
N,
K,
std::array<index_t, 1>{StrideA},
std::array<index_t, 1>{StrideB},
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,
b_element_op,
c_element_op);
}
// polymorphic
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
{
return std::make_unique<Invoker>(Invoker{});
}
// polymorphic
std::string GetTypeString() const override
{
auto str = std::stringstream();
std::map<BlockGemmPipelineScheduler, std::string> BlkGemmPipelineSchedulerToString{
{BlockGemmPipelineScheduler::Intrawave, "Intrawave"},
{BlockGemmPipelineScheduler::Interwave, "Interwave"}};
std::map<BlockGemmPipelineVersion, std::string> BlkGemmPipelineVersionToString{
{BlockGemmPipelineVersion::v1, "v1"},
{BlockGemmPipelineVersion::v2, "v2"},
{BlockGemmPipelineVersion::v3, "v3"},
{BlockGemmPipelineVersion::v4, "v4"},
{BlockGemmPipelineVersion::v5, "v5"}};
// clang-format off
str << "DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle"
<< "<"
<< getGemmSpecializationString(GemmSpec) << ", "
<< std::string(ALayout::name)[0]
<< std::string(BLayout::name)[0]
<< std::string(CLayout::name)[0]
<< ">"
<< " BlkSize: "
<< BlockSize << ", "
<< "BlkTile: "
<< MPerBlock<<"x"<<NPerBlock<<"x"<<KPerBlock << ", "
<< "WaveTile: "
<< MPerWmma<<"x"<<NPerWmma << ", "
<< "WaveMap: "
<< MRepeat<<"x" << NRepeat<<", "
<< "VmemReadVec: "
<< ABlockTransferSrcScalarPerVector<<"x"<<BBlockTransferSrcScalarPerVector<<", "
<< "BlkGemmPipelineScheduler: "
<< BlkGemmPipelineSchedulerToString[BlkGemmPipeSched] << ", "
<< "BlkGemmPipelineVersion: "
<< BlkGemmPipelineVersionToString[BlkGemmPipelineVer] << ", "
<< "BlkGemmPipelinePrefetchStages: "
<< GridwiseGemm::BlockwiseGemmPipe::PrefetchStages << ", "
<< "KPack: "
<< GridwiseGemm::KPack;
// clang-format on
return str.str();
}
};
} // namespace device
} // namespace tensor_operation
} // namespace ck

View File

@@ -262,6 +262,16 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
@@ -279,22 +289,47 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Full)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel =
kernel_gemm_xdl_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
}
}
}
@@ -315,6 +350,20 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
{
auto& arg = *dynamic_cast<Argument*>(base_arg);
arg.KBatch = KBatch;
if(get_warp_size() == 64)
{
arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch);
arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch);
arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch);
arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch);
}
else
{
arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch);
arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch);
arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch);
arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch);
}
}
static constexpr bool IsValidCompilationParameter()
@@ -325,6 +374,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
static bool IsSupportedArgument(const Argument& arg)
{
// with splitk the implementation doesn't work
// when KRead % ScaleBlockK != 0, independently of K padding
if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0)
{
return false;
}
if(!ck::is_xdl_wmma_supported<ComputeTypeA, ComputeTypeB, MPerXDL, NPerXDL>())
{
return false;
@@ -385,6 +441,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op)
{
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return Argument{static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
@@ -396,6 +460,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
StrideB,
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,
@@ -425,6 +491,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
BElementwiseOperation b_element_op,
CElementwiseOperation c_element_op) override
{
index_t StrideScaleA = ck::is_same_v<ALayout, tensor_layout::gemm::RowMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(M, ScaleBlockM);
index_t StrideScaleB = ck::is_same_v<BLayout, ck::tensor_layout::gemm::ColumnMajor>
? math::integer_divide_ceil(K, ScaleBlockK)
: math::integer_divide_ceil(N, ScaleBlockN);
return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
static_cast<const BDataType*>(p_b),
p_ds,
@@ -436,6 +510,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
StrideB,
StrideDs,
StrideC,
StrideScaleA,
StrideScaleB,
static_cast<const AScaleDataType*>(p_a_scale),
static_cast<const BScaleDataType*>(p_b_scale),
1,

View File

@@ -12,7 +12,7 @@
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/flush_cache.hpp"
@@ -86,12 +86,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
{
// GridwiseGemm
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale<
using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale<
ALayout,
BLayout,
Tuple<>, // DsLayout
CLayout,
Tuple<ADataType>,
void, // AScaleType
Tuple<BDataType>,
BScaleDataType,
AccDataType,
@@ -103,6 +104,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
CElementwiseOperation,
GemmSpec,
BlockSize,
0, // ScaleBlockM
ScaleBlockN,
ScaleBlockK,
MPerBlock,
@@ -207,7 +209,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
0, // StrideScaleA
StrideScaleB,
nullptr,
p_b_scale,
KBatch,
a_element_op,
@@ -245,7 +249,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale<ALayout,
std::array<index_t, 1>{StrideB},
std::array<index_t, 0>{}, // StrideDs_
StrideC,
0, // StrideScaleA
StrideScaleB,
nullptr, // p_a_scale
static_cast<const BScaleDataType*>(p_b_scale),
KBatch,
a_element_op,

View File

@@ -38,7 +38,8 @@ template <typename GridwiseGemm,
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB>
typename ComputeTypeB,
bool IsBPreShuffled = false>
struct DeviceGemm_Wmma_CShuffleV3_Common
{
@@ -189,61 +190,174 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
if(has_main_k_block_loop)
{
// Tail number always full
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
if constexpr(IsBPreShuffled)
{
if(arg.KBatch > 1)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if constexpr(AtomicsImplementationExists)
if(arg.KBatch > 1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
{
// TODO: Implement
}
}
else
{
// Tail number always 1
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 ||
BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
{
if constexpr(AtomicsImplementationExists)
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
GridwiseGemm,
true,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
true,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
else
}
}
else
{
if constexpr(IsBPreShuffled)
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) ==
TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
else
{
if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd)
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Odd>;
Run(kernel);
}
else
{
const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy,
TailNumber::Even>;
Run(kernel);
}
}
}
}
else
{
if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1)
{
if(arg.KBatch > 1)
{
if constexpr(AtomicsImplementationExists)
{
const auto kernel = kernel_gemm_wmma_cshuffle_v3<
GridwiseGemm,
false,
InMemoryDataOperationEnum::AtomicAdd,
minimum_occupancy>;
Run(kernel);
}
}
else
{
const auto kernel =
kernel_gemm_wmma_cshuffle_v3<GridwiseGemm,
false,
InMemoryDataOperationEnum::Set,
minimum_occupancy>;
Run(kernel);
}
}
}
}
@@ -299,6 +413,14 @@ struct DeviceGemm_Wmma_CShuffleV3_Common
return false;
}
if constexpr(IsBPreShuffled)
{
if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0)
{
return false;
}
}
return GridwiseGemm::CheckValidity(arg);
}
};

View File

@@ -388,11 +388,11 @@ struct ABTransferThreadTiles
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(make_unmerge_transform(make_tuple(
Number<ABK0 / KRow>{}, KRow, Number<KPack / KRow / ABK1>{})),
make_unmerge_transform(make_tuple(
Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
}

View File

@@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
c_thread_buf.Clear();
// Empty BScale struct for the blockwise pipeline.
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using ABScale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = ABScale{};
auto b_scale_struct = ABScale{};
/*******************************************************************************/
//
@@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
b0_block_buf,
b0_block_slice_copy_step,
acc0_thread_buf,
a_scale_struct,
b_scale_struct,
KBlockMainLoop,
1); // num_k_block_per_scale

View File

@@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3
__builtin_amdgcn_readfirstlane(block_work_idx[Number<BlockMapNBlockIndex>{}]);
// BScale struct (Empty)
using BScale = typename BlockwiseGemmPipe::Empty;
auto b_scale_struct = BScale{};
using Scale = typename BlockwiseGemmPipe::Empty;
auto a_scale_struct = Scale{};
auto b_scale_struct = Scale{};
const index_t num_k_block_per_scale = GetKBlockPerScale();
@@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
@@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args,
k_id);

View File

@@ -23,6 +23,7 @@ template <typename ALayout,
typename DsLayout,
typename ELayout,
typename AsDataType,
typename AScaleType,
typename BsDataType,
typename BScaleType,
typename AccDataType,
@@ -34,6 +35,7 @@ template <typename ALayout,
typename CDEElementwiseOperation,
tensor_operation::device::GemmSpecialization GemmSpec,
index_t BlockSize,
index_t ScaleBlockM,
index_t ScaleBlockN, // scale N
index_t ScaleBlockK, // scale K
index_t MPerBlock,
@@ -65,13 +67,16 @@ template <typename ALayout,
index_t CShuffleNRepeatPerShuffle,
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
typename CDEShuffleBlockTransferScalarPerVectors,
BlockGemmPipelineScheduler BlkGemmPipeSched = BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion BlkGemmPipelineVer = BlockGemmPipelineVersion::v4,
typename ComputeTypeA = EDataType,
typename ComputeTypeB = ComputeTypeA,
bool PermuteA = false,
bool PermuteB = false>
struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BlockGemmPipelineScheduler BlkGemmPipeSched,
BlockGemmPipelineVersion BlkGemmPipelineVer,
typename ComputeTypeA,
typename ComputeTypeB,
bool PermuteA,
bool PermuteB,
bool IsBPreShuffled = false,
typename AScaleLayout = ALayout,
typename BScaleLayout = BLayout>
struct GridwiseGemm_wmma_cshuffle_v3_ab_scale
: GridwiseGemm_wmma_cshuffle_v3_base<
ALayout,
BLayout,
@@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>
{
using Base = GridwiseGemm_wmma_cshuffle_v3_base<
@@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
ComputeTypeB,
PermuteA,
PermuteB,
false,
IsBPreShuffled,
true>;
using Base::I0;
@@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
@@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs{StrideBs_},
StrideDs{StrideDs_},
StrideE{StrideE_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
@@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
AK0{CalculateAK0Padded(K_, KBatch_)},
BK0{CalculateBK0Padded(K_, KBatch_)},
MBlock{CalculateMBlock(M_)},
NBlock{CalculateNBlock(N_)}
NBlock{CalculateNBlock(N_)},
Kt{K_}
{
}
@@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
});
std::cout << " }, ";
}
std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", "
<< "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead
<< ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0
<< ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}"
<< std::endl;
std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", "
<< "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", "
<< "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded
<< ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", "
<< "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl;
}
index_t M;
@@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideE;
index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
@@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
index_t BK0;
index_t MBlock;
index_t NBlock;
index_t Kt;
};
// Argument
@@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumBTensor> StrideBs_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideE_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
@@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
StrideBs_,
StrideDs_,
StrideE_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_as_grid{},
p_bs_grid{},
p_ds_grid{},
p_e_grid{p_e_grid_},
p_a_scale_grid{p_a_scale_grid_},
p_b_scale_grid{p_b_scale_grid_},
a_element_op{a_element_op_},
b_element_op{b_element_op_},
@@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
DsGridPointer p_ds_grid;
EDataType* p_e_grid;
const AScaleType* p_a_scale_grid;
const BScaleType* p_b_scale_grid;
const AElementwiseOperation a_element_op;
const BElementwiseOperation b_element_op;
@@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
[&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; });
}
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
if constexpr(IsBPreShuffled)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; });
static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; });
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else
{
if constexpr(!PermuteB)
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
static_for<0, NumBTensor, 1>{}([&](auto i) {
b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i];
});
}
else
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
if constexpr(!PermuteB)
{
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; });
}
else
{
const int k0_offset = karg.KRead * karg.N;
static_for<0, NumBTensor, 1>{}(
[&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; });
}
}
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB;
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, AScaleLayout>)
{
scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BScaleLayout>)
{
scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK);
}
if(k_id < karg.KBatch - 1)
@@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
std::array<index_t, NumATensor> a_k_split_offset;
std::array<index_t, NumBTensor> b_k_split_offset;
index_t scale_k_split_offset; // New member for scale matrix offset
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
index_t c_reduce_offset;
};
using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe;
// return block_id to C matrix tile idx (m0, n0) mapping
// if arch = gfx942
using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>;
// using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit<MPerBlock, NPerBlock>;
template <index_t NumberOfBuffers, typename BScaleGridDesc_BN_AK>
__device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak,
const BScaleType* p_b_scale_grid,
index_t block_n_id)
__device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}
static constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma;
template <index_t NumberOfBuffers>
__device__ static auto
MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id)
{
if constexpr(ck::is_same_v<AScaleType, void>)
{
using AScale = typename BlockwiseGemmPipe::Empty;
return AScale{};
}
else
{
#if defined(__gfx11__)
// TODO: remove this restriction
static_assert(ScaleBlockM >= MPerWmma,
"ScaleBlockM must be greater equal than MPerWmma");
#endif
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");
static constexpr auto ScaleSliceSizeN = NRepeat;
static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK;
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
const auto a_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize());
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
constexpr auto wmma =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>{};
constexpr auto RegSizePerWmmaFull =
wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number;
constexpr auto RegSizePerWmma =
math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM);
auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma;
auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread;
constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma);
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
1,
ScaleSliceSizeK,
1,
false>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n,
b_thread_offset_k / ScaleBlockK));
constexpr auto ScaleSliceSizeM =
ScaleBlockM < MPerWmma ? MRepeat * RegSizePerWmma
: math::integer_divide_ceil(MPerBlock, ScaleBlockM);
constexpr auto ScaleSliceStrideM =
math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
b_scale_thread_desc.GetElementSpaceSize());
constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeM>{}, Number<ScaleSliceSizeK>{}));
using BScale =
typename BlockwiseGemmPipe::template BScale<ScaleSliceSizeN,
ScaleSliceSizeK,
NWaves,
ScaleBlockK,
NumberOfBuffers,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
auto a_thread_offset_m =
((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) /
math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) +
(get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM;
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
constexpr index_t VectorDim =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::ColumnMajor, AScaleLayout>::value ? RegSizePerWmma
: ScaleSliceSizeK;
auto a_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<AScaleType,
AScaleType,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_desc),
Sequence<RegSizePerWmma, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
a_scale_grid_desc_am_ak,
make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0));
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AScaleType>(
a_scale_thread_desc.GetElementSpaceSize());
using AScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeM,
ScaleSliceStrideM,
ScaleSliceSizeK,
NumberOfBuffers,
RegSizePerWmma,
decltype(a_scale_grid_desc_am_ak),
decltype(a_scale_thread_copy),
decltype(a_scale_grid_buf),
decltype(a_scale_thread_buf),
decltype(a_scale_thread_desc)>;
return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf};
}
}
__device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}
template <index_t NumberOfBuffers>
__device__ static auto
MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id)
{
if constexpr(ck::is_same_v<BScaleType, void>)
{
using BScale = typename BlockwiseGemmPipe::Empty;
return BScale{};
}
else
{
static_assert(
ScaleBlockK >=
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::
selected_wmma.k_per_wmma,
"ScaleBlockK must be greater equal than KPerWmma");
const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
const auto b_scale_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize());
constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma);
constexpr auto ScaleSliceSizeN =
ScaleBlockN < NPerWmma ? NRepeat
: math::integer_divide_ceil(NPerBlock, ScaleBlockN);
constexpr auto ScaleSliceStrideN =
math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN);
constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK);
constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed(
make_tuple(Number<ScaleSliceSizeN>{}, Number<ScaleSliceSizeK>{}));
auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma +
(get_thread_local_1d_id() / 32) % NWaves * NPerWmma) /
ScaleBlockN;
constexpr index_t VectorDim =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 0 : 1;
constexpr index_t VectorSize =
is_same<tensor_layout::gemm::RowMajor, BScaleLayout>::value ? 1 : ScaleSliceSizeK;
auto b_scale_thread_copy =
ThreadwiseTensorSliceTransfer_v2<BScaleType,
BScaleType,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_desc),
Sequence<1, ScaleSliceSizeK>,
Sequence<0, 1>,
VectorDim,
VectorSize,
1,
true>(
b_scale_grid_desc_bn_ak,
make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0));
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BScaleType>(
b_scale_thread_desc.GetElementSpaceSize());
using BScale =
typename BlockwiseGemmPipe::template ABScale<ScaleSliceSizeN,
ScaleSliceStrideN,
ScaleSliceSizeK,
NumberOfBuffers,
1,
decltype(b_scale_grid_desc_bn_ak),
decltype(b_scale_thread_copy),
decltype(b_scale_grid_buf),
decltype(b_scale_thread_buf),
decltype(b_scale_thread_desc)>;
return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf};
}
}
__device__ static index_t GetKBlockPerScale()
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
if constexpr(ck::is_same_v<AScaleType, void> && ck::is_same_v<BScaleType, void>)
{
return 0;
}
else
{
return (ScaleBlockK + KPerBlock - 1) / KPerBlock;
}
}
template <bool HasMainKBlockLoop,
@@ -539,18 +720,21 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
BsGridPointer& p_bs_grid,
DsGridPointer& p_ds_grid,
EDataType* p_e_grid,
const AScaleType* p_a_scale_grid,
const BScaleType* p_b_scale_grid,
void* p_shared,
const Problem& problem,
AElementwiseOperation a_element_op,
BElementwiseOperation b_element_op,
CDEElementwiseOperation cde_element_op,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
const auto as_grid_desc_ak0_m_ak1 = MakeAsGridDescriptor_AK0_M_AK1(
problem.M, problem.MPadded, problem.K, problem.KPadded, problem.StrideAs, problem.AK0);
const index_t K_b = IsBPreShuffled ? problem.Kt : problem.K;
const auto bs_grid_desc_bk0_n_bk1 = MakeBsGridDescriptor_BK0_N_BK1(
problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
K_b, problem.KPadded, problem.N, problem.NPadded, problem.StrideBs, problem.BK0);
const auto ds_grid_desc_m_n = MakeDsGridDescriptor_M_N(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideDs);
const auto e_grid_desc_m_n = Base::template MakeDEGridDescriptor_M_N<ELayout>(
@@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
e_grid_desc_m_n, problem.MBlock, problem.NBlock);
// B Scale grid
const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor(
make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN),
math::integer_divide_ceil(problem.K, ScaleBlockK)),
make_tuple(problem.StrideScaleB, 1));
// divide block work by [M, N]
const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4};
@@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]);
// AScale struct
auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id);
// BScale struct
auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id);
auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id);
const index_t num_k_block_per_scale = GetKBlockPerScale();
@@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
decltype(bs_grid_desc_bk0_n_bk1),
decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(e_grid_desc_mblock_mperblock_nblock_nperblock),
decltype(a_scale_struct),
decltype(b_scale_struct),
decltype(epilogue_args),
HasMainKBlockLoop,
@@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
block_m_id,
block_n_id,
num_k_block_per_scale,
a_scale_struct,
b_scale_struct,
epilogue_args);
epilogue_args,
k_id);
}
// NOTE: Wrapper function to have __global__ function in common
@@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
__device__ static void Run(void* p_shared,
const SplitKBatchOffset& splitk_batch_offset,
Argument& karg,
EpilogueArgument& epilogue_args)
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
{
// shift A matrices pointer for splitk
AsGridPointer p_as_grid_splitk;
@@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale
splitk_batch_offset.b_k_split_offset[i];
});
const AScaleType* p_a_scale_grid_ptr;
if constexpr(ck::is_same_v<AScaleType, void>)
{
p_a_scale_grid_ptr = karg.p_a_scale_grid;
}
else
{
p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset;
}
const BScaleType* p_b_scale_grid_ptr;
if constexpr(ck::is_same_v<BScaleType, void>)
{
p_b_scale_grid_ptr = karg.p_b_scale_grid;
}
else
{
p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset;
}
Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_as_grid_splitk,
p_bs_grid_splitk,
karg.p_ds_grid,
karg.p_e_grid + splitk_batch_offset.c_reduce_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset,
p_a_scale_grid_ptr,
p_b_scale_grid_ptr,
p_shared,
karg,
karg.a_element_op,
karg.b_element_op,
karg.cde_element_op,
epilogue_args);
epilogue_args,
k_id);
}
};

View File

@@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
}
template <typename GridwiseGemm,
bool HasMainKBlockLoop,
InMemoryDataOperationEnum EGlobalMemoryDataOperation,
index_t MinimumOccupancy = 1,
TailNumber TailNum = TailNumber::Full>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
#endif
kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg)
{
#if(defined(__gfx11__) || defined(__gfx12__))
#if defined(__gfx11__)
// gfx11 does not support *_atomic_pk_add_f16/bf16 instructions
using e_data_type = remove_cvref_t<remove_pointer_t<decltype(karg.p_e_grid)>>;
if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd &&
(std::is_same_v<e_data_type, ck::half_t> ||
std::is_same_v<e_data_type, ck::bhalf_t>)))
{
#endif
constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte<
typename GridwiseGemm::EpilogueCShuffle>();
__shared__ char p_shared[LDS_size];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z);
const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack);
const index_t k_id = blockIdx.z * num_k_per_block;
auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{};
GridwiseGemm::template Run<HasMainKBlockLoop, EGlobalMemoryDataOperation, TailNum>(
p_shared, splitk_batch_offset, karg, epilogue_args, k_id);
#if defined(__gfx11__)
}
#endif
#else
ignore = karg;
#endif
}
template <typename ALayout,
typename BLayout,
typename DsLayout,
@@ -162,7 +204,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
static constexpr index_t KInner = IsBPreShuffled ? KInnerB : ck::math::min(KInnerA, KInnerB);
static constexpr index_t KPack =
KInner *
@@ -966,6 +1008,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
typename BGridDesc_BK0_N_K1,
typename DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
typename AScaleStruct,
typename BScaleStruct,
typename EpilogueArgument,
bool HasMainKBlockLoop,
@@ -988,6 +1031,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
const index_t& block_m_id,
const index_t& block_n_id,
const index_t& num_k_block_per_scale,
AScaleStruct& a_scale_struct,
BScaleStruct& b_scale_struct,
EpilogueArgument& epilogue_args,
const index_t k_id = 0)
@@ -1072,6 +1116,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
b_block_buf,
b_block_slice_copy_step,
c_thread_buf,
a_scale_struct,
b_scale_struct,
num_k_block_main_loop,
num_k_block_per_scale);

View File

@@ -43,13 +43,15 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy)
{
__shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg);
GridwiseGemm::template Run<HasMainKBlockLoop, CGlobalMemoryDataOperation, TailNum>(
karg.p_a_grid,
karg.p_b_grid,
karg.p_a_grid + splitk_batch_offset.a_k_split_offset,
karg.p_b_grid + splitk_batch_offset.b_k_split_offset,
karg.p_ds_grid,
karg.p_c_grid,
karg.p_a_scale_grid,
karg.p_b_scale_grid,
karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset,
karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset,
p_shared,
karg,
karg.a_element_op,
@@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
}
}
__host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K)
__host__ __device__ static constexpr auto
MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA)
{
const auto BM = math::integer_divide_ceil(M, ScaleBlockM);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1));
}
else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM));
return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA));
}
}
__host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K)
__host__ __device__ static constexpr auto
MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB)
{
const auto BN = math::integer_divide_ceil(N, ScaleBlockN);
const auto BK = math::integer_divide_ceil(K, ScaleBlockK);
if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1));
}
else if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
{
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN));
return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB));
}
}
@@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
index_t KBatch_)
: M{M_},
N{N_},
@@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
StrideB{StrideB_},
StrideDs{StrideDs_},
StrideC{StrideC_},
StrideScaleA{StrideScaleA_},
StrideScaleB{StrideScaleB_},
KBatch{KBatch_},
MPadded{CalculateMPadded(M_)},
NPadded{CalculateNPadded(N_)},
@@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB;
std::array<index_t, NumDTensor> StrideDs;
index_t StrideC;
index_t StrideScaleA;
index_t StrideScaleB;
index_t KBatch;
index_t MPadded;
index_t NPadded;
@@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t StrideB_,
std::array<index_t, NumDTensor> StrideDs_,
index_t StrideC_,
index_t StrideScaleA_,
index_t StrideScaleB_,
const AScaleType* p_a_scale_grid_,
const BScaleType* p_b_scale_grid_,
index_t k_batch_,
AElementwiseOperation a_element_op_,
BElementwiseOperation b_element_op_,
CElementwiseOperation c_element_op_)
: Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_},
: Problem{M_,
N_,
K_,
StrideA_,
StrideB_,
StrideDs_,
StrideC_,
StrideScaleA_,
StrideScaleB_,
k_batch_},
p_a_grid{p_a_grid_},
p_b_grid{p_b_grid_},
p_ds_grid{},
@@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
b_k_split_offset = blockIdx.z * karg.KRead;
}
// Calculate A scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
{
scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
{
scale_a_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA;
}
// Calculate B scale offset
if constexpr(is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
scale_b_k_split_offset =
blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB;
}
else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, BLayout>)
{
scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK);
}
if(blockIdx.z < static_cast<uint32_t>(karg.KBatch - 1))
{
karg.K = karg.KRead;
@@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
index_t a_k_split_offset;
index_t b_k_split_offset;
index_t scale_a_k_split_offset; // A scale matrix offset
index_t scale_b_k_split_offset; // B scale matrix offset
};
__device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1()
@@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N<CLayout>(
problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC);
const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K);
const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K);
const auto a_scale_grid_desc_am_ak =
MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA);
const auto b_scale_grid_desc_bn_ak =
MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB);
const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(

View File

@@ -380,236 +380,143 @@ struct sequence_reduce<Reduce, Seq>
};
#endif
template <typename Values, typename Ids, typename Compare>
struct sequence_sort_impl
// Implement sequence_sort and sequence_unique_sort using constexpr functions (C++17)
namespace sort_impl {
// Temporary arrays to hold values during operations with capacity N and mutable size.
template <index_t N>
struct IndexedValueArray
{
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl
{
static constexpr bool choose_left = LeftValues::Front() < RightValues::Front();
static constexpr index_t chosen_value =
choose_left ? LeftValues::Front() : RightValues::Front();
static constexpr index_t chosen_id = choose_left ? LeftIds::Front() : RightIds::Front();
using new_merged_values = decltype(MergedValues::PushBack(Number<chosen_value>{}));
using new_merged_ids = decltype(MergedIds::PushBack(Number<chosen_id>{}));
using new_left_values =
typename conditional<choose_left, decltype(LeftValues::PopFront()), LeftValues>::type;
using new_left_ids =
typename conditional<choose_left, decltype(LeftIds::PopFront()), LeftIds>::type;
using new_right_values =
typename conditional<choose_left, RightValues, decltype(RightValues::PopFront())>::type;
using new_right_ids =
typename conditional<choose_left, RightIds, decltype(RightIds::PopFront())>::type;
using merge = sorted_sequence_merge_impl<new_left_values,
new_left_ids,
new_right_values,
new_right_ids,
new_merged_values,
new_merged_ids,
Comp>;
// this is output
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
template <typename LeftValues,
typename LeftIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<LeftValues,
LeftIds,
Sequence<>,
Sequence<>,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, LeftValues>::type;
using merged_ids = typename sequence_merge<MergedIds, LeftIds>::type;
};
template <typename RightValues,
typename RightIds,
typename MergedValues,
typename MergedIds,
typename Comp>
struct sorted_sequence_merge_impl<Sequence<>,
Sequence<>,
RightValues,
RightIds,
MergedValues,
MergedIds,
Comp>
{
using merged_values = typename sequence_merge<MergedValues, RightValues>::type;
using merged_ids = typename sequence_merge<MergedIds, RightIds>::type;
};
template <typename LeftValues,
typename LeftIds,
typename RightValues,
typename RightIds,
typename Comp>
struct sorted_sequence_merge
{
using merge = sorted_sequence_merge_impl<LeftValues,
LeftIds,
RightValues,
RightIds,
Sequence<>,
Sequence<>,
Comp>;
using merged_values = typename merge::merged_values;
using merged_ids = typename merge::merged_ids;
};
static constexpr index_t nsize = Values::Size();
using split_unsorted_values = sequence_split<Values, nsize / 2>;
using split_unsorted_ids = sequence_split<Ids, nsize / 2>;
using left_unsorted_values = typename split_unsorted_values::left_type;
using left_unsorted_ids = typename split_unsorted_ids::left_type;
using left_sort = sequence_sort_impl<left_unsorted_values, left_unsorted_ids, Compare>;
using left_sorted_values = typename left_sort::sorted_values;
using left_sorted_ids = typename left_sort::sorted_ids;
using right_unsorted_values = typename split_unsorted_values::right_type;
using right_unsorted_ids = typename split_unsorted_ids::right_type;
using right_sort = sequence_sort_impl<right_unsorted_values, right_unsorted_ids, Compare>;
using right_sorted_values = typename right_sort::sorted_values;
using right_sorted_ids = typename right_sort::sorted_ids;
using merged_sorted = sorted_sequence_merge<left_sorted_values,
left_sorted_ids,
right_sorted_values,
right_sorted_ids,
Compare>;
using sorted_values = typename merged_sorted::merged_values;
using sorted_ids = typename merged_sorted::merged_ids;
index_t values[N > 0 ? N : 1];
index_t ids[N > 0 ? N : 1];
index_t size = 0;
};
template <index_t ValueX, index_t ValueY, index_t IdX, index_t IdY, typename Compare>
struct sequence_sort_impl<Sequence<ValueX, ValueY>, Sequence<IdX, IdY>, Compare>
template <index_t... Is>
constexpr auto make_indexed_value_array(Sequence<Is...>)
{
static constexpr bool choose_x = Compare{}(ValueX, ValueY);
constexpr index_t N = sizeof...(Is);
IndexedValueArray<N> result = {{Is...}, {}, N};
for(index_t i = 0; i < N; ++i)
{
result.ids[i] = i;
}
return result;
}
using sorted_values =
typename conditional<choose_x, Sequence<ValueX, ValueY>, Sequence<ValueY, ValueX>>::type;
using sorted_ids = typename conditional<choose_x, Sequence<IdX, IdY>, Sequence<IdY, IdX>>::type;
enum class SortField
{
Values,
Ids
};
template <index_t Value, index_t Id, typename Compare>
struct sequence_sort_impl<Sequence<Value>, Sequence<Id>, Compare>
// Perform an insertion sort on an IndexedValueArray.
template <index_t N, typename Compare>
constexpr auto insertion_sort(IndexedValueArray<N> arr, Compare comp)
{
using sorted_values = Sequence<Value>;
using sorted_ids = Sequence<Id>;
for(index_t i = 1; i < arr.size; ++i)
{
index_t key_val = arr.values[i];
index_t key_id = arr.ids[i];
index_t j = i - 1;
while(j >= 0 && comp(key_val, arr.values[j]))
{
arr.values[j + 1] = arr.values[j];
arr.ids[j + 1] = arr.ids[j];
--j;
}
arr.values[j + 1] = key_val;
arr.ids[j + 1] = key_id;
}
return arr;
}
// Remove duplicates from a sorted IndexedValueArray.
template <index_t N, typename Equal>
constexpr auto unique(const IndexedValueArray<N>& sorted, Equal eq)
{
IndexedValueArray<N> result{};
if constexpr(N == 0)
{
return result;
}
result.size = 1;
result.values[0] = sorted.values[0];
result.ids[0] = sorted.ids[0];
for(index_t i = 1; i < sorted.size; ++i)
{
if(!eq(sorted.values[i], sorted.values[i - 1]))
{
result.values[result.size] = sorted.values[i];
result.ids[result.size] = sorted.ids[i];
++result.size;
}
}
return result;
}
// Compute sorted (and optionally unique) IndexedValueArray from input Sequence.
template <bool Unique, typename Compare, typename Equal, index_t... Is>
constexpr auto compute_sorted(Sequence<Is...> seq, Compare comp, Equal eq)
{
auto sorted = insertion_sort(make_indexed_value_array(seq), comp);
return Unique ? unique(sorted, eq) : sorted;
}
// Cache the sorted results to avoid recomputation.
template <bool Unique, typename Seq, typename Compare, typename Equal>
struct SortedCache
{
static constexpr auto data = compute_sorted<Unique>(Seq{}, Compare{}, Equal{});
};
template <typename Compare>
struct sequence_sort_impl<Sequence<>, Sequence<>, Compare>
// Build sorted value and ID sequences from cached sorted data
template <SortField Field, bool Unique, typename Seq, typename Compare, typename Equal, index_t I>
constexpr index_t get_sorted_field()
{
using sorted_values = Sequence<>;
using sorted_ids = Sequence<>;
constexpr auto& data = SortedCache<Unique, Seq, Compare, Equal>::data;
return (Field == SortField::Values) ? data.values[I] : data.ids[I];
}
template <bool Unique, typename Seq, typename Compare, typename Equal, typename IndexSeq>
struct SortedSequences;
template <bool Unique, typename Seq, typename Compare, typename Equal, index_t... Is>
struct SortedSequences<Unique, Seq, Compare, Equal, Sequence<Is...>>
{
using values_type =
Sequence<get_sorted_field<SortField::Values, Unique, Seq, Compare, Equal, Is>()...>;
using ids_type =
Sequence<get_sorted_field<SortField::Ids, Unique, Seq, Compare, Equal, Is>()...>;
};
template <bool Unique, typename Seq, typename Compare, typename Equal>
using sorted_sequences_t = SortedSequences<
Unique,
Seq,
Compare,
Equal,
typename arithmetic_sequence_gen<0, SortedCache<Unique, Seq, Compare, Equal>::data.size, 1>::
type>;
using Equal = ck::math::equal<index_t>;
} // namespace sort_impl
template <typename Values, typename Compare>
struct sequence_sort
{
using unsorted_ids = typename arithmetic_sequence_gen<0, Values::Size(), 1>::type;
using sort = sequence_sort_impl<Values, unsorted_ids, Compare>;
// this is output
using type = typename sort::sorted_values;
using sorted2unsorted_map = typename sort::sorted_ids;
using sorted_seqs = sort_impl::sorted_sequences_t<false, Values, Compare, sort_impl::Equal>;
using type = typename sorted_seqs::values_type;
using sorted2unsorted_map = typename sorted_seqs::ids_type;
};
template <typename Values, typename Less, typename Equal>
struct sequence_unique_sort
{
template <typename RemainValues,
typename RemainIds,
typename UniquifiedValues,
typename UniquifiedIds,
typename Eq>
struct sorted_sequence_uniquify_impl
{
static constexpr index_t current_value = RemainValues::Front();
static constexpr index_t current_id = RemainIds::Front();
static constexpr bool is_unique_value = (current_value != UniquifiedValues::Back());
using new_remain_values = decltype(RemainValues::PopFront());
using new_remain_ids = decltype(RemainIds::PopFront());
using new_uniquified_values =
typename conditional<is_unique_value,
decltype(UniquifiedValues::PushBack(Number<current_value>{})),
UniquifiedValues>::type;
using new_uniquified_ids =
typename conditional<is_unique_value,
decltype(UniquifiedIds::PushBack(Number<current_id>{})),
UniquifiedIds>::type;
using uniquify = sorted_sequence_uniquify_impl<new_remain_values,
new_remain_ids,
new_uniquified_values,
new_uniquified_ids,
Eq>;
// this is output
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
template <typename UniquifiedValues, typename UniquifiedIds, typename Eq>
struct sorted_sequence_uniquify_impl<Sequence<>,
Sequence<>,
UniquifiedValues,
UniquifiedIds,
Eq>
{
using uniquified_values = UniquifiedValues;
using uniquified_ids = UniquifiedIds;
};
template <typename SortedValues, typename SortedIds, typename Eq>
struct sorted_sequence_uniquify
{
using uniquify = sorted_sequence_uniquify_impl<decltype(SortedValues::PopFront()),
decltype(SortedIds::PopFront()),
Sequence<SortedValues::Front()>,
Sequence<SortedIds::Front()>,
Eq>;
using uniquified_values = typename uniquify::uniquified_values;
using uniquified_ids = typename uniquify::uniquified_ids;
};
using sort = sequence_sort<Values, Less>;
using sorted_values = typename sort::type;
using sorted_ids = typename sort::sorted2unsorted_map;
using uniquify = sorted_sequence_uniquify<sorted_values, sorted_ids, Equal>;
// this is output
using type = typename uniquify::uniquified_values;
using sorted2unsorted_map = typename uniquify::uniquified_ids;
using sorted_seqs = sort_impl::sorted_sequences_t<true, Values, Less, Equal>;
using type = typename sorted_seqs::values_type;
using sorted2unsorted_map = typename sorted_seqs::ids_type;
};
template <typename SeqMap>

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/algorithm/cluster_descriptor.hpp"

View File

@@ -1550,9 +1550,10 @@ CK_TILE_DEVICE thread_buffer<T, N> amd_buffer_load_impl(int32x4_t src_wave_buffe
(std::is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, e8m0_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_int4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32) ||
(std::is_same<T, pk_fp4_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16))),
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16 || N == 32)) ||
(std::is_same<T, pk_fp4_raw_t>::value &&
(N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
(std::is_same<T, pk_fp4_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
"wrong! not implemented");
using rtn_type = thread_buffer<T, N>;

View File

@@ -39,8 +39,12 @@
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#if __clang_major__ < 22
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#else
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline

View File

@@ -41,9 +41,8 @@ struct scales
Scale lhs_;
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename Scale>
__host__ __device__ scales(Scale) -> scales<Scale>;
CK_TILE_HOST_DEVICE_EXTERN scales(Scale) -> scales<Scale>;
template <typename Left = void, typename Right = Left>
struct plus
@@ -66,8 +65,7 @@ struct plus<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ plus() -> plus<void, void>;
CK_TILE_HOST_DEVICE_EXTERN plus() -> plus<void, void>;
template <typename Left = void, typename Right = Left>
struct minus
@@ -90,8 +88,7 @@ struct minus<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ minus() -> minus<void, void>;
CK_TILE_HOST_DEVICE_EXTERN minus() -> minus<void, void>;
template <typename Left = void, typename Right = Left>
struct multiplies
@@ -114,8 +111,7 @@ struct multiplies<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ multiplies() -> multiplies<void, void>;
CK_TILE_HOST_DEVICE_EXTERN multiplies() -> multiplies<void, void>;
template <typename T>
struct maximize
@@ -345,8 +341,7 @@ struct equal<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ equal() -> equal<void, void>;
CK_TILE_HOST_DEVICE_EXTERN equal() -> equal<void, void>;
template <>
struct equal<float, float>
@@ -387,8 +382,7 @@ struct less<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less() -> less<void, void>;
CK_TILE_HOST_DEVICE_EXTERN less() -> less<void, void>;
template <typename Left = void, typename Right = Left>
struct less_equal
@@ -411,8 +405,7 @@ struct less_equal<void, void>
}
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
__host__ __device__ less_equal() -> less_equal<void, void>;
CK_TILE_HOST_DEVICE_EXTERN less_equal() -> less_equal<void, void>;
template <>
struct less_equal<float, float>

View File

@@ -47,9 +47,8 @@ struct composes<F>
F f_;
};
/// FIXME: create macro to replace '__host__ __device__' and nothing more
template <typename... Ts>
__host__ __device__ composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
CK_TILE_HOST_DEVICE_EXTERN composes(Ts&&...) -> composes<remove_cvref_t<Ts>...>;
template <typename SaturateType>
struct saturates

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/host/arg_parser.hpp"

View File

@@ -52,9 +52,19 @@ template <typename ComputeDataType, typename OutDataType, typename AccDataType =
CK_TILE_HOST double get_relative_threshold(const int number_of_accumulations = 1)
{
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
static_assert(is_any_of<ComputeDataType,
F8,
BF8,
F16,
BF16,
F32,
pk_fp4_t,
pk_fp4_raw_t,
pk_int4_t,
I8,
I32,
int>::value,
"Warning: Unhandled ComputeDataType for setting up the relative threshold!");
double compute_error = 0;
if constexpr(is_any_of<ComputeDataType, pk_int4_t, I8, I32, int>::value)
@@ -113,9 +123,19 @@ CK_TILE_HOST double get_absolute_threshold(const double max_possible_num,
const int number_of_accumulations = 1)
{
static_assert(
is_any_of<ComputeDataType, F8, BF8, F16, BF16, F32, pk_int4_t, I8, I32, int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
static_assert(is_any_of<ComputeDataType,
F8,
BF8,
F16,
BF16,
F32,
pk_fp4_t,
pk_fp4_raw_t,
pk_int4_t,
I8,
I32,
int>::value,
"Warning: Unhandled ComputeDataType for setting up the absolute threshold!");
auto expo = std::log2(std::abs(max_possible_num));
double compute_error = 0;

View File

@@ -372,6 +372,63 @@ CK_TILE_HOST void reference_gemm_tensor_quant(const HostTensor<ADataType>& a_m_k
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}
template <typename ADataType,
typename QDataType,
typename BDataType,
typename AccDataType,
typename CDataType,
typename QuantGroupSize,
bool aquant,
typename AElementOp = ck_tile::identity,
typename BElementOp = ck_tile::identity,
typename ACCElementOp = ck_tile::identity>
CK_TILE_HOST void reference_mxfp4gemm_quant(const HostTensor<ADataType>& a_m_k,
const HostTensor<QDataType>& q,
const HostTensor<BDataType>& b_k_n,
HostTensor<CDataType>& c_m_n,
const AElementOp& a_element_op = {},
const BElementOp& b_element_op = {},
const ACCElementOp& acc_element_op = {})
{
const std::size_t M = a_m_k.get_length(0);
const std::size_t N = b_k_n.get_length(1);
const std::size_t K = a_m_k.get_length(1);
auto f_mn = [&](auto m, auto n) {
AccDataType v_acc = 0;
AccDataType pasual = 0;
for(std::size_t k = 0; k < (K / 2); k++)
{
using ComputeType = float;
auto b_scale = type_convert<int32_t>(q((2 * k) / QuantGroupSize::kK, n)) - 127;
ComputeType v_a_0, v_a_1;
ComputeType v_b_0, v_b_1;
v_a_0 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k))));
v_a_1 = ck_tile::type_convert<ComputeType>((a_element_op(a_m_k(m, 2 * k + 1))));
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
{
auto b_pack = type_convert<pk_fp4_t>(b_element_op(b_k_n(k, n)));
auto b_scale_fp4 = type_convert<float>(std::pow(2.0f, b_scale));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
v_b_0 = type_convert<ComputeType>(b_f4_lo) * b_scale_fp4;
v_b_1 = type_convert<ComputeType>(b_f4_hi) * b_scale_fp4;
}
pasual = v_a_0 * v_b_0 + v_a_1 * v_b_1;
v_acc += pasual;
}
c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
};
make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
std::cout << std::endl;
}
template <typename ADataType,
typename BDataType,
typename AccDataType,

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_contraction/kernel/batched_contraction_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"

View File

@@ -22,6 +22,7 @@ template <> struct DataTypeTraits<bf8_t> { static constexpr const char * name =
template <> struct DataTypeTraits<int8_t> { static constexpr const char * name = "int8"; };
template <> struct DataTypeTraits<pk_int4_t> { static constexpr const char * name = "pk_int4"; };
template <> struct DataTypeTraits<pk_fp4_t> { static constexpr const char * name = "pk_fp4"; };
template <> struct DataTypeTraits<pk_fp4_raw_t> { static constexpr const char * name = "pk_fp4_raw"; };
template <memory_operation_enum MemOp> struct memOpToStr;
template <> struct memOpToStr<memory_operation_enum::set> { static constexpr const char * name = "set"; };

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"

View File

@@ -92,11 +92,17 @@ struct CShuffleEpilogue
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using ATypeToUse = std::conditional_t<std::is_same_v<ADataType, pk_int4_t> ||
std::is_same_v<ADataType, pk_fp4_t>,
BDataType,
ADataType>;
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
ADataType,
BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/flatmm/block/block_flatmm_asmem_bsmem_creg_v1.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/fused_moegemm_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"

View File

@@ -96,8 +96,10 @@ struct BlockUniversalGemmAsBsCr
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
ADataType,
BDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;

View File

@@ -17,10 +17,12 @@ struct GemmPipelineAgBgCrImplBase
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BDataType =
std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>, ADataType, BInDataType>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
@@ -270,12 +272,17 @@ struct GemmPipelineAgBgCrImplBase
}();
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
using BLdsDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
auto b_lds_load_tile_distr = []() {
if constexpr(is_b_load_tr)
return make_static_tile_distribution(
typename InputTileDistributionTraits<
typename BLdsLoadTileDistr::DstrEncode,
typename Problem::BDataType>::TransposedDstrEncode{});
typename InputTileDistributionTraits<typename BLdsLoadTileDistr::DstrEncode,
BLdsDataType>::TransposedDstrEncode{});
else
return BLdsLoadTileDistr{};
}();

View File

@@ -303,8 +303,11 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
@@ -585,9 +588,12 @@ struct UniversalGemmBasePolicy
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
using BDataType = std::conditional_t<std::is_same_v<BInDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
if constexpr(Problem::FixedVectorSize)
{
@@ -729,13 +735,17 @@ struct UniversalGemmBasePolicy
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t KPerBlock = std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
? Problem::BlockGemmShape::kK / 2
: Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
std::is_same_v<BDataType, ck_tile::pk_fp4_raw_t>
? 4
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
using BLayout = remove_cvref_t<
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
@@ -841,10 +851,12 @@ struct UniversalGemmBasePolicy
template <typename Problem>
CK_TILE_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b =
integer_least_multiple(sizeof(typename Problem::BDataType) *
Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK,
16);
using BDataType =
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>;
constexpr index_t smem_size_b = integer_least_multiple(
sizeof(BDataType) * Problem::BlockGemmShape::kN * Problem::BlockGemmShape::kK, 16);
return smem_size_b;
}
@@ -882,8 +894,10 @@ struct UniversalGemmPipelineAgBgCrPolicy
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using ATypeToUse =
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
std::is_same_v<BDataType, pk_fp4_raw_t>,
ADataType,
BDataType>;
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
BTypeToUse,

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
@@ -21,6 +20,9 @@
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_bquant_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_group_quant_utils.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_v3.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_quant_pipeline_problem.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_base_policy.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_wp_bquant_pipeline_ag_bg_cr_v2.hpp"

View File

@@ -759,12 +759,20 @@ struct QuantGemmKernel
}
else
{
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k / 2),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
else
return make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N, splitk_batch_offset.splitted_k),
make_tuple(kargs.stride_B, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
}
}
}
@@ -900,10 +908,16 @@ struct QuantGemmKernel
const auto& b_tensor_view = views.at(I2);
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
sequence<false, GemmPipeline::kPadK>{});
else
return pad_tensor_view(b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
@@ -1046,10 +1060,17 @@ struct QuantGemmKernel
{
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
if constexpr(std::is_same_v<BDataType, pk_fp4_raw_t>)
return make_tile_window(
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock / 2>{}),
{i_n, 0});
else
return make_tile_window(b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{},
number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
}
else
{

View File

@@ -0,0 +1,59 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
namespace ck_tile {
template <typename Problem, typename Policy>
struct GemmMxFp4PipelineAgBgCrImplBase : public GemmPipelineAgBgCrImplBase<Problem, Policy>
{
using Base = GemmPipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = typename Base::ADataType;
using ALayout = typename Base::ALayout;
using BDataType = typename Base::BDataType;
using BLayout = typename Base::BLayout;
using BlockGemmShape = typename Base::BlockGemmShape;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = NPerBlock / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = KPerBlock / QuantGroupSize::kK;
static_assert(NPerBlockBQ >= 1, "NPerBlock must be >= QuantGroupSize");
static_assert(KPerBlockBQ >= 1, "KPerBlock must be >= QuantGroupSize");
static_assert(NPerBlock % QuantGroupSize::kN == 0,
"NPerBlock must be a multiple of QuantGroupSize::kN");
static_assert(KPerBlock % QuantGroupSize::kK == 0,
"KPerBlock must be a multiple of QuantGroupSize::kK");
// Create DRAM tile window for BQ
template <typename BQDramBlockWindowTmp>
CK_TILE_DEVICE constexpr auto
GetBQDramLoadWindow(const BQDramBlockWindowTmp& bq_dram_block_window_tmp) const
{
static_assert(std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>);
using YPerTile = number<NPerBlockBQ>;
using XPerTile = number<KPerBlockBQ>;
auto bq_copy_dram_window =
make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
make_tuple(YPerTile(), XPerTile()),
bq_dram_block_window_tmp.get_window_origin(),
Policy::template MakeBQDramTileDistribution<Problem>());
return bq_copy_dram_window;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,140 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "gemm_group_quant_utils.hpp"
namespace ck_tile {
struct GemmMxFp4PipelineAgBgCrPolicy : public UniversalGemmPipelineAgBgCrPolicy
{
using Base = UniversalGemmPipelineAgBgCrPolicy;
using Base::I0;
using Base::I1;
using Base::I2;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeBQ()
{
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlockBQ = NPerBlock / Problem::QuantGroupSize::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlockBQ = KPerBlock / Problem::QuantGroupSize::kK;
static_assert(std::is_same_v<BQLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
return GetABQGlobalVectorLoadSize<Problem, BQDataType, NPerBlockBQ, KPerBlockBQ>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBRegTileDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
// Tile: NPerBlock X KPerBlock
else
{
using TileEncodingPattern =
tile_distribution_encoding_pattern_2d<BlockSize,
NPerBlock,
KPerBlock,
VecLoadSize,
getBTileAccessPattern(),
NumWaveGroups>;
return TileEncodingPattern::make_2d_static_tile_distribution();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBQDramTileDistribution()
{
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KScale = KPerBlock / Problem::QuantGroupSize::kK; // k_scale num //2
constexpr index_t VecLoadSize =
Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>();
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
constexpr index_t warp_size = get_warp_size();
constexpr index_t num_warps = BlockSize / get_warp_size();
constexpr index_t LargestVec = (KPerBlock * NPerBlock) / (num_warps * warp_size);
constexpr index_t b_vec = VecLoadSize > LargestVec ? LargestVec : VecLoadSize;
constexpr index_t K0 = KPerBlock / b_vec;
constexpr index_t K1 = K0 / KScale;
constexpr index_t K3 = K0 / K1;
constexpr index_t K2 = 1;
constexpr index_t N0 = num_warps / NumWaveGroups;
constexpr index_t N1 = warp_size / K0;
constexpr index_t N2 = NPerBlock / (N0 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<K1>,
tuple<sequence<N0, N1, N2>, sequence<K3, K2>>,
tuple<sequence<1>, sequence<1, 2, 0>>,
tuple<sequence<0>, sequence<1, 0, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
static_assert(Problem::QuantGroupSize::kK % WarpTile::at(I2) == 0,
"KPerWarpGemm must be a multiple of QuantGroupSize!");
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
typename Problem::ComputeDataType,
typename Problem::CDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
Problem::TransposeC>;
static_assert(std::is_same_v<typename Problem::ComputeDataType, fp8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf8_t> ||
std::is_same_v<typename Problem::ComputeDataType, bf16_t>);
static_assert(std::is_same_v<typename Problem::CDataType, float>);
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<
typename Problem::ADataType,
std::conditional_t<std::is_same_v<typename Problem::BDataType, pk_fp4_raw_t>,
typename Problem::ADataType,
typename Problem::BDataType>,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,665 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
#include <sstream>
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm_quant/pipeline/gemm_mxfp4_pipeline_ag_bg_cr_base.hpp"
#include "ck_tile/host/concat.hpp"
namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem, typename Policy = GemmMxFp4PipelineAgBgCrPolicy>
struct MxFp4GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
{
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmMxFp4PipelineAgBgCrImplBase<Problem, Policy>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BDqDataType = remove_cvref_t<typename Problem::ADataType>;
using BQDataType = remove_cvref_t<typename Problem::BQDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
using QuantGroupSize = remove_cvref_t<typename Problem::QuantGroupSize>;
static_assert(QuantGroupSize::kM == 1, "only N/K blocks for BQuant kernel!");
using I0 = number<0>;
using I1 = number<1>;
using I2 = number<2>;
static constexpr index_t APackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
static constexpr index_t BPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDqDataType>>::PackedSize;
static constexpr index_t BQPackedSize =
ck_tile::numeric_traits<remove_cvref_t<BQDataType>>::PackedSize;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BQLayout = remove_cvref_t<typename Problem::BQLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
static constexpr index_t BlockSize = Problem::kBlockSize;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t NPerBlockBQ = BlockGemmShape::kN / QuantGroupSize::kN;
static constexpr index_t KPerBlockBQ = BlockGemmShape::kK / QuantGroupSize::kK;
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetVectorSizeBQ()
{
return Policy::template GetVectorSizeBQ<Problem>();
}
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler;
using Base::PrefetchStages;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
return concat('_', "mxfp4gemm_pipeline_AgBgCrCompV3",
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK),
concat('x', kPadM, kPadN, kPadK), QuantGroupSize::GetName());
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
CK_TILE_HOST static std::string Print()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t BQ_Buffer_Load_Inst_Num =
NPerBlock * KPerBlockBQ / (BlockSize * GetVectorSizeBQ());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
auto str = std::stringstream{};
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << ", "
<< "BQ vector size: " << GetVectorSizeBQ() << "\n"
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
<< ", " << "BQ buffer load inst: " << BQ_Buffer_Load_Inst_Num << "\n"
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
<< "\n"
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
<< "QuantGroupSize: " << QuantGroupSize::GetName() << "\n"
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
<< "PrefetchStages: " << PrefetchStages << "\n";
return str.str();
}
template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase
{
};
template <>
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
{
using Base = PipelineImplBase;
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
{
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
constexpr index_t WaveSize = 64;
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
// Below should be equal to AK1|BK1
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
constexpr index_t A_Buffer_Load_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
constexpr index_t B_Buffer_Load_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
constexpr index_t A_LDS_Write_Inst_Num =
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
constexpr index_t B_LDS_Write_Inst_Num =
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
constexpr index_t A_LDS_Read_Inst_Num =
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
constexpr index_t B_LDS_Read_Inst_Num =
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
(BlockSize / WaveSize) /
(MPerXDL * NPerXDL * KPerXDL);
// A/B split schedule
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
constexpr auto num_ds_read_inst_a =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
: A_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_read_inst_b =
B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16
? B_LDS_Read_Inst_Num
: B_LDS_Read_Inst_Num / 2;
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
constexpr auto ds_read_a_issue_cycle =
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
constexpr auto ds_read_b_issue_cycle =
B_LDS_Read_Width * sizeof(BDqDataType) / BPackedSize == 16 ? 8 : 4;
constexpr auto ds_read_a_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
constexpr auto ds_read_b_mfma_rate =
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
constexpr auto num_dsread_a_mfma =
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
constexpr auto num_dsread_b_mfma =
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) /
// sizeof(BDataType)
// ? sizeof(ComputeDataType) /
// sizeof(ADataType) : sizeof(ComputeDataType)
// / sizeof(BDataType);
constexpr auto num_mfma_stage1 =
num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
});
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
ignore = i;
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
ignore = idswrite;
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
__builtin_amdgcn_sched_group_barrier(
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
});
// stage 2
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
ds_read_a_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
ds_read_b_mfma_rate)
{
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
}
else
{
__builtin_amdgcn_sched_group_barrier(
0x100,
num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
0); // DS read
}
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
});
}
template <bool HasHotLoop,
TailNumber TailNum,
typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp,
typename AElementFunction,
typename BElementFunction>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const AElementFunction& a_element_func,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BElementFunction& b_element_func,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem) const
{
static_assert(
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
std::is_same_v<BDataType,
remove_cvref_t<typename BDramBlockWindowTmp::DataType>> &&
std::is_same_v<BQDataType,
remove_cvref_t<typename BQDramBlockWindowTmp::DataType>>,
"A/B/BQ Dram block window should have the same data type as appropriate "
"([A|B|BQ]DataType) defined in Problem definition!");
constexpr bool is_a_col_major =
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_bq_col_major =
std::is_same_v<BQLayout, tensor_layout::gemm::ColumnMajor>;
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
static_assert(is_bq_col_major, "Bq must be col major (row major not supported yet)");
static_assert(NPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlockBQ == BQDramBlockWindowTmp{}.get_window_lengths()[I1{}],
"Bq block window has incorrect lengths for defined BqLayout!");
static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(
is_b_row_major
? (KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock / 2 == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// int b_block_stride = 0;
// A/B tiles in LDS
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
// Tile distribution for load from lds
constexpr auto a_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
constexpr auto b_lds_load_tile_distr =
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
// A DRAM tile window for load
// A LDS tile window for store
// A LDS tile for block GEMM
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
// B DRAM tile window for load, (kN, kK/2)
// B LDS tile window for store, (kN, kK)
// B LDS tile for block GEMM
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
// B scale DRAM tile window for load
// auto b_scale_copy_dram_window =
// make_tile_window(bq_dram_block_window_tmp.get_bottom_tensor_view(),
// bq_dram_block_window_tmp.get_window_lengths(),
// bq_dram_block_window_tmp.get_window_origin(),
// Policy::template GetBQDramLoadWindow<Problem>());
auto bq_copy_dram_window = Base::GetBQDramLoadWindow(bq_dram_block_window_tmp);
auto bq_block_tile = decltype(load_tile(bq_copy_dram_window)){};
// Block GEMM
auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
// using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
using ABlockTile =
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
using BBlockTile =
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
ABlockTile a_block_tile;
BBlockTile b_fp4_block_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock / 2, 0) : make_array(0, KPerBlock / 2);
constexpr index_t b_scale_dram_tile_window_step = KPerBlock / QuantGroupSize::kK;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
// auto a_scale_block_tile = decltype(load_tile(a_scale_copy_dram_window)){};
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
// BDataType
auto b_block_tile = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeBRegTileDistribution<Problem>());
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
constexpr auto idx1_js = tile_distributed_index<0>{};
constexpr auto b_block = decltype(b_fp4_block_tile)::get_distributed_spans();
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto b_scale_uint = type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
constexpr auto idx1_hi = tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
b_block_tile(i_j_idx_lo) =
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
b_block_tile(i_j_idx_hi) =
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
});
});
// initialize C
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
block_sync_lds();
// LDS write 0
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto b_scale_uint = type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
constexpr auto idx1_hi = tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
b_block_tile(i_j_idx_lo) =
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
b_block_tile(i_j_idx_hi) =
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
});
});
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
__builtin_amdgcn_sched_barrier(0);
// main body
if constexpr(HasHotLoop)
{
index_t i = 0;
do
{
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
Base::GlobalPrefetch(a_block_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_fp4_block_tile, b_copy_dram_window, b_dram_tile_window_step);
bq_block_tile = load_tile(bq_copy_dram_window);
move_tile_window(bq_copy_dram_window, {0, b_scale_dram_tile_window_step});
sweep_tile_span(b_block[number<0>{}], [&](auto idx0) {
sweep_tile_span(b_block[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
constexpr auto i_j_idx_scale = make_tuple(idx0, idx1_js);
auto b_scale_uint =
type_convert<int32_t>(bq_block_tile(i_j_idx_scale)) - 127;
auto b_scale = type_convert<float>(std::pow(2.0f, b_scale_uint));
constexpr auto idx1_lo = tile_distributed_index<idx1.impl_.at(0) * 2>{};
constexpr auto idx1_hi =
tile_distributed_index<idx1.impl_.at(0) * 2 + 1>{};
constexpr auto i_j_idx_lo = make_tuple(idx0, idx1_lo);
constexpr auto i_j_idx_hi = make_tuple(idx0, idx1_hi);
auto b_pack = type_convert<pk_fp4_t>(b_fp4_block_tile(i_j_idx));
auto b_f4_lo = type_convert<pk_fp4_t>(b_pack.unpack(number<0>{}));
auto b_f4_hi = type_convert<pk_fp4_t>(b_pack.unpack(number<1>{}));
b_block_tile(i_j_idx_lo) =
type_convert<bf16_t>(type_convert<float>(b_f4_lo) * b_scale);
b_block_tile(i_j_idx_hi) =
type_convert<bf16_t>(type_convert<float>(b_f4_hi) * b_scale);
});
});
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
i += 1;
// b_block_stride +=1;
} while(i < (num_loop - 1));
}
// tile_elementwise_inout([](auto& c) { c = 0; }, acc_block_tile);
// tail
if constexpr((TailNum == TailNumber::Full) || (TailNum == TailNumber::Odd))
{
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
// latency
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}
else
{
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_block_tile);
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window, a_block_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDqDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_block_tile);
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window, b_block_tile, b_element_func);
}
block_sync_lds();
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds();
}
__builtin_amdgcn_sched_barrier(0);
return c_block_tile;
}
};
/**
* @brief This function runs the pipeline using compile-time known hot loop and tail number.
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
* SplitK.
* @note This is used by the kernel variants that are able to determine
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
*/
template <typename ADramBlockWindowTmp,
typename BDramBlockWindowTmp,
typename BQDramBlockWindowTmp>
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
const BDramBlockWindowTmp& b_dram_block_window_tmp,
const BQDramBlockWindowTmp& bq_dram_block_window_tmp,
index_t num_loop,
void* p_smem,
index_t n = 0) const
{
ck_tile::ignore = n;
return PipelineImpl<Scheduler>{}.template operator()<HasHotLoop, TailNum>(
a_dram_block_window_tmp,
[](const ADataType& a) { return a; },
b_dram_block_window_tmp,
[](const BDqDataType& b) { return b; },
bq_dram_block_window_tmp,
num_loop,
p_smem);
}
};
} // namespace ck_tile

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/grouped_convolution/kernel/grouped_convolution_backward_data_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"

View File

@@ -1,6 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,5 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

View File

@@ -1,7 +1,6 @@
# Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
# SPDX-License-Identifier: MIT
from datetime import datetime
import pathlib
from pathlib import Path
import subprocess
@@ -13,8 +12,8 @@ OPS = "ops"
OPS_COMMON = "common" # common header will be duplicated into ops/* other module
IGNORED_DIRS = ["utility", "ref"]
HEADER_COMMON = f"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n
HEADER_COMMON = """// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
"""