Extend support for ak1 / bk1 WMMA (#3073)

* Extend AK1 / BK1 support:

 - Add support for AK1 != BK1
 - Add support for AK1, BK1 > 8
 - Introduce a KInner template parameter for the pipelines, used when loading multiple tiles with one instruction

* fix clang format
This commit is contained in:
Enrico Degregori
2025-11-11 16:38:15 +01:00
committed by GitHub
parent 9f33b7cfd3
commit 1c544abf57
24 changed files with 621 additions and 339 deletions

View File

@@ -28,6 +28,7 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
constexpr auto BlockGemmPipeline_Selector()
{
@@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>{};
}
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
@@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector()
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>{};
}
else

View File

@@ -30,6 +30,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_base
{
@@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
static constexpr auto I2 = Number<2>{};
static constexpr auto I3 = Number<3>{};
static constexpr auto I5 = Number<5>{};
static constexpr auto I6 = Number<6>{};
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base
static constexpr index_t B_KRow = 1;
#endif
static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
ComputeTypeB,
AccDataType,
MPerWmma,
NPerWmma,
KPack / KInner,
TransposeC>{};
static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);
static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
static constexpr auto wmma_gemm =
WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
static constexpr index_t KRepeat = KPerBlock / KPack;
static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
@@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
const auto wmma_krow = 0;
#endif
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
}
__device__ static auto CalculateBThreadOriginDataIndex()
@@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
const auto wmma_krow = 0;
#endif
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
}
template <index_t m0, index_t n0>
@@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
return make_tuple(c_thread_m, c_thread_n);
}
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
using Tuple7 = decltype(CalculateAThreadOriginDataIndex());
/**
* @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
@@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base
* repeat dimensions.
*/
__host__ __device__
BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
Tuple7 b_origin = CalculateBThreadOriginDataIndex())
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
{
static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
@@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<KRepeat>{},
I1,
I1,
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack / A_KRow * MRepeat>{},
I0,
I0,
I0,
I1));
static constexpr auto b_thread_desc_ =
@@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
Number<KRepeat>{},
I1,
I1,
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack / B_KRow * NRepeat>{},
I0,
I0,
I0,
I1));
// C[M, N, NumRegWmma]
@@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
A_K1,
A_K1>;
@@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B_K1,
B_K1>;

View File

@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v1
{
@@ -55,6 +56,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
@@ -114,10 +118,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>;
using Base::I0;
using Base::I1;
using Base::WaveSize;
using typename Base::HotLoopInstList;
using Base::A_K1;
@@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
index_t num_loop,
index_t num_loop_per_scale) const
{
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
@@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
auto blockwise_gemm_func = [&]() {
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0, I0, I0, I0, I0),
a_thread_buf);
if constexpr(m0 == I0)
{
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
else
@@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, I0, I0, I0, I0),
make_tuple(I0, n0, I0, I0, I0, I0, I0),
b_thread_buf);
});
}
}
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, I0, I0, I0, I0, Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
I0,
I0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
I0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, I0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
});
}
static_for<0, NRepeat, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
static_for<0, KInner, 1>{}([&](auto) {
static_for<0, NRepeat, 1>{}([&](auto) {
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
});
});
});
});
@@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
protected:
// A[MRepeat, I1, I1, KPack]
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, Number<A_K1>{}));
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
// B[NRepeat, N1, N2, KPack]
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}));
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
using AThreadCopy =
ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
A_K1,
A_K1>;
@@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B_K1,
B_K1>;
@@ -399,6 +418,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
BlockSize,
@@ -419,6 +439,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -438,6 +459,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
@@ -458,6 +480,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>;
using Base::I0;
using Base::I1;
@@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
index_t num_loop,
index_t num_loop_per_scale) const
{
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
@@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
m0,
I0,
I0,
I0,
I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0_inner, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
a_thread_buf);
});
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
n0,
I0,
I0,
I0,
I0),
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0_inner, I0, I0, I0),
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
b_thread_buf);
});
}
@@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
n0,
I0,
I0,
I0,
I0),
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(I0)[Number<
n0 * BScaleStruct::num_scale_k_block +
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0_inner, I0, I0, I0),
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
b_thread_buf);
});
}
@@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<ik / A_K1>{},
m0,
k0_inner,
I0,
I0,
Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0_inner,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0_inner,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard.
// B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts.
// It is performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
m0 == MRepeat - 1 && n0 == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<ik / B_K1>{},
n0,
k0_inner,
I0,
I0,
Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard.
// B) reduce VMEM FIFO congestion by applying small delays to
// different wavefronts.
// It is performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
n0 == NRepeat - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
@@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
Number<KRepeatPerCluster>{},
I1,
I1,
I1,
Number<A_K1>{}),
make_tuple(Number<A_K1>{},
Number<KPack / A_KRow>{},
Number<KPack / A_KRow * MRepeat>{},
I0,
I0,
I0,
I1));
static constexpr auto b_thread_desc_ =
@@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
Number<KRepeatPerCluster>{},
I1,
I1,
I1,
Number<B_K1>{}),
make_tuple(Number<B_K1>{},
Number<KPack / B_KRow>{},
Number<KPack / B_KRow * NRepeat>{},
I0,
I0,
I0,
I1));
using AThreadCopy =
@@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
ComputeTypeA,
decltype(a_block_desc_k0_m0_m1_m2_k1),
decltype(a_thread_desc_),
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
A_K1,
A_K1>;
@@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
ComputeTypeB,
decltype(b_block_desc_k0_n0_n1_n2_k1),
decltype(b_thread_desc_),
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5>,
5,
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
Sequence<0, 1, 2, 3, 4, 5, 6>,
6,
B_K1,
B_K1>;

View File

@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC = false>
struct BlockwiseGemmWmmaops_pipeline_v3
{
@@ -55,6 +56,7 @@ template <index_t BlockSize,
index_t MRepeat,
index_t NRepeat,
index_t KPack,
index_t KInner,
bool TransposeC>
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
BlockSize,
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
ADataType,
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>
{
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
@@ -114,6 +118,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
MRepeat,
NRepeat,
KPack,
KInner,
TransposeC>;
using Base::I0;
@@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
{
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(
a_block_desc_k0_m0_m1_m2_k1,
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0),
a_thread_buf);
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, m0, k0, I0, I0, I0, I0),
a_thread_buf);
});
if constexpr(ck::is_same_v<BScaleStruct, Empty>)
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
}
else
{
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(
b_block_desc_k0_n0_n1_n2_k1,
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0),
b_thread_buf);
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_block_buf,
b_scale_struct.b_scale_thread_bufs(
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
k0 / BScaleStruct::num_scale_krepeat>{}],
b_thread_desc_,
make_tuple(I0, n0, k0, I0, I0, I0, I0),
b_thread_buf);
});
}
});
@@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
index_t num_loop_per_scale) const
{
__builtin_amdgcn_sched_barrier(0);
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
@@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<ik / A_K1>{},
m0,
k0,
I0,
I0,
Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(
a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<ik / B_K1>{},
n0,
k0,
I0,
I0,
Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});
@@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
static_for<0, KInner, 1>{}([&](auto k_inner) {
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(Number<kk / A_K1>{},
m0,
k0,
I0,
I0,
I0,
Number<kk % A_K1>{}))>{}];
});
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(Number<kk / B_K1>{},
n0,
k0,
I0,
I0,
I0,
Number<kk % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
});
using wmma_input_type_a =
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
using wmma_input_type_b =
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
b_thread_vec.template AsType<wmma_input_type_b>(),
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
});
});
});

View File

@@ -17,6 +17,9 @@ template <typename ABLayout,
index_t KPerBlock,
index_t MNPerWmma,
index_t ABK1Value,
index_t KPack,
index_t KInner,
index_t KPerWmmaBlk,
bool UseBlockPaddingAB,
bool PermuteAB,
typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
@@ -374,14 +377,93 @@ struct ABTransferThreadTiles
#else
constexpr auto KRow = I1;
#endif
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow)),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
if constexpr(KInner > 1)
{
// KPack = KInner * KPerWmma
// K1 = KInner * KPerWmmaBlk
// Each thread loads multiple tiles with one instruction
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
}
else
{
// KPack = KPerWmma (KInner == 1)
if constexpr(ABK1 <= KPerWmmaBlk)
{
// K1 <= single tile (KPerWmmaBlk)
// Each thread will load KPerWmmaBlk for the WMMA instruction
// Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
// (rest of the single WMMA tile for single thread) and then over KRow
// (rest of the single WMMA tile for single wave)
// KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
return transform_tensor_descriptor(
BlockDesc{},
make_tuple(
make_unmerge_transform(make_tuple(
Number<ABK0 / (KPack / ABK1)>{}, KRow, Number<KPack / KRow / ABK1>{})),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_pass_through_transform(Number<ABK1>{})),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
}
else
{
// K1 > single tile (KPerWmmaBlk)
// Each thread will load KPerWmmaBlk for the WMMA instruction
// Since K1 > single tile, each thread loads KPerWmmaBlk and the next
// KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout).
// This layout is needed to support for example AK1 > single tile and
// BK1 <= single tile in the same gemm
// KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma -
// K1
constexpr auto desc1 = transform_tensor_descriptor(
BlockDesc{},
make_tuple(
make_pass_through_transform(Number<ABK0>{}),
make_unmerge_transform(
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
make_unmerge_transform(make_tuple(Number<ABK1 / KPack>{},
Number<KPack / KPerWmmaBlk / KRow>{},
Number<KRow>{},
Number<KPerWmmaBlk>{}))),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{}));
return transform_tensor_descriptor(
desc1,
make_tuple(
make_pass_through_transform(Number<KPack / KPerWmmaBlk / KRow>{}),
make_pass_through_transform(Number<MNRepeat>{}),
make_merge_transform(make_tuple(Number<ABK0>{}, Number<ABK1 / KPack>{})),
make_pass_through_transform(Number<MNWaves>{}),
make_pass_through_transform(Number<KRow>{}),
make_pass_through_transform(Number<MNPerWmma>{}),
make_pass_through_transform(Number<KPerWmmaBlk>{})),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2, 3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{},
Sequence<7>{}),
make_tuple(Sequence<0>{},
Sequence<1>{},
Sequence<2>{},
Sequence<3>{},
Sequence<4>{},
Sequence<5>{},
Sequence<6>{}));
}
}
}
__device__ static constexpr auto GetBlockStep()

View File

@@ -313,14 +313,16 @@ struct ABTransferWaveTiles
// This is a block descriptor used to read LDS memory into register
// It's defined in a way consistent with the existing implementation to
// avoid changes in the pipelines
return make_naive_tensor_descriptor(make_tuple(Number<KPerBlock / KPack>{},
return make_naive_tensor_descriptor(make_tuple(I1,
Number<MNRepeat>{},
Number<KPerBlock / KPack>{},
Number<MNWaves>{},
Number<MNKRow>{},
Number<MNPerWmma>{},
Number<ABK1Value>{}),
make_tuple(Number<KPack * MNPerWmma>{},
make_tuple(I0,
Number<KPerBlock * MNPerWmma * MNWaves>{},
Number<KPack * MNPerWmma>{},
Number<KPerBlock * MNPerWmma>{},
Number<MNPerWmma * ABK1Value>{},
Number<ABK1Value>{},

View File

@@ -109,9 +109,20 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
// TODO: I am pretty sure this is always 16 and *should* always be 16.
static constexpr auto KPack =
math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16);
static constexpr index_t KPerWmmaBlk =
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
.k_per_blk;
static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk);
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
static constexpr index_t KPack =
KInner *
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
.k_per_wmma;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -201,54 +212,115 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
return b1_block_copy_step;
}
// Reshape a 3D LDS block descriptor (K0, MN, K1) into the 7D per-wave tile layout
// consumed by the blockwise WMMA pipeline:
//   (KPack-chunk, MNRepeat, K-iteration, MNWaves, KRow, MNPerWmma, K1)
// The branch taken depends on how K1 relates to the hardware tile size KPerWmmaBlk
// and on the KInner factor (number of WMMA tiles loaded per instruction).
// NOTE(review): KInner, KPack and KPerWmmaBlk are enclosing-struct constants —
// presumably the ones computed from WmmaSelector earlier in this struct; confirm.
template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
__host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
{
    // K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1
    constexpr auto K0 = BlockDesc{}.GetLength(I0);
    constexpr auto K1 = BlockDesc{}.GetLength(I2);
#ifdef __gfx12__
    // gfx12 distributes K across two rows of the wave; other targets use one.
    constexpr auto KRow = I2;
#else
    constexpr auto KRow = I1;
#endif
    if constexpr(KInner > 1)
    {
        // KPack = KInner * KPerWmma
        // K1 = KInner * KPerWmmaBlk
        // Each thread loads multiple tiles with one instruction
        // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
        // The leading dimension is a degenerate Number<1>, so the whole KPack
        // chunk lives in the contiguous K1 axis.
        return transform_tensor_descriptor(
            BlockDesc{},
            make_tuple(
                make_unmerge_transform(make_tuple(Number<K0 / (KRow)>{}, KRow, Number<1>{})),
                make_unmerge_transform(
                    make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
                make_pass_through_transform(Number<K1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
    }
    else
    {
        // KPack = KPerWmma (KInner == 1)
        if constexpr(K1 <= KPerWmmaBlk)
        {
            // K1 <= single tile (KPerWmmaBlk)
            // Each thread will load KPerWmmaBlk for the WMMA instruction
            // Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
            // (rest of the single WMMA tile for single thread) and then over KRow
            // (rest of the single WMMA tile for single wave)
            // KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
            return transform_tensor_descriptor(
                BlockDesc{},
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<K0 / (KPack / K1)>{}, KRow, Number<KPack / KRow / K1>{})),
                           make_unmerge_transform(make_tuple(
                               Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
                           make_pass_through_transform(Number<K1>{})),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
        }
        else
        {
            // K1 > single tile (KPerWmmaBlk)
            // Each thread will load KPerWmmaBlk for the WMMA instruction
            // Since K1 > single tile, each thread loads KPerWmmaBlk and the next
            // KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout).
            // This layout is needed to support for example AK1 > single tile and
            // BK1 <= single tile in the same gemm
            // KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma -
            // K1
            // First split K1 into (K1/KPack, KPack/KPerWmmaBlk/KRow, KRow, KPerWmmaBlk),
            // then merge (K0, K1/KPack) back into a single K-iteration dimension.
            constexpr auto desc1 = transform_tensor_descriptor(
                BlockDesc{},
                make_tuple(
                    make_pass_through_transform(Number<K0>{}),
                    make_unmerge_transform(
                        make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
                    make_unmerge_transform(make_tuple(Number<K1 / KPack>{},
                                                      Number<KPack / KPerWmmaBlk / KRow>{},
                                                      Number<KRow>{},
                                                      Number<KPerWmmaBlk>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{}));
            return transform_tensor_descriptor(
                desc1,
                make_tuple(make_pass_through_transform(Number<KPack / KPerWmmaBlk / KRow>{}),
                           make_pass_through_transform(Number<MNRepeat>{}),
                           make_merge_transform(make_tuple(Number<K0>{}, Number<K1 / KPack>{})),
                           make_pass_through_transform(Number<MNWaves>{}),
                           make_pass_through_transform(Number<KRow>{}),
                           make_pass_through_transform(Number<MNPerWmma>{}),
                           make_pass_through_transform(Number<KPerWmmaBlk>{})),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2, 3>{},
                           Sequence<4>{},
                           Sequence<5>{},
                           Sequence<6>{},
                           Sequence<7>{}),
                make_tuple(Sequence<0>{},
                           Sequence<1>{},
                           Sequence<2>{},
                           Sequence<3>{},
                           Sequence<4>{},
                           Sequence<5>{},
                           Sequence<6>{}));
        }
    }
}
// Build the per-wave A-matrix descriptor from the AK0_M_AK1 block descriptor.
// NOTE(review): this span is a merged diff view — the `a_wave_desc` lambda below is
// the PRE-change implementation and the trailing `return MakeWmmaTileDescriptor...`
// is its replacement. As written here the second return is unreachable; in the real
// post-commit file only the MakeWmmaTileDescriptor call remains — confirm against
// the repository before relying on this text.
template <typename ABlockDesc_>
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
{
    constexpr auto a_wave_desc = [&]() {
        // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
        constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
        constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
        // gfx12 distributes K across two rows of the wave; other targets use one.
        constexpr auto A_KRow = I2;
#else
        constexpr auto A_KRow = I1;
#endif
        return transform_tensor_descriptor(
            ABlockDesc_{},
            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
                       make_unmerge_transform(
                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
                       make_pass_through_transform(Number<A_K1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
    }();
    return a_wave_desc;
    // Replacement implementation (post-commit): delegate to the shared helper.
    return MakeWmmaTileDescriptor<MRepeat, MWaves, MPerWmma>(ABlockDesc_{});
}
// Build the per-wave B0-matrix descriptor from the BK0_L_BK1 block descriptor.
// NOTE(review): this span is a merged diff view — the `b0_wave_desc` lambda below is
// the PRE-change implementation and the trailing `return MakeWmmaTileDescriptor...`
// is its replacement. As written here the second return is unreachable; in the real
// post-commit file only the MakeWmmaTileDescriptor call remains — confirm against
// the repository before relying on this text.
template <typename B0BlockDesc_>
__host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
{
    constexpr auto b0_wave_desc = [&]() {
        // BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1
        constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
        constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
#ifdef __gfx12__
        // gfx12 distributes K across two rows of the wave; other targets use one.
        constexpr auto B_KRow = I2;
#else
        constexpr auto B_KRow = I1;
#endif
        return transform_tensor_descriptor(
            B0BlockDesc_{},
            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
                       make_unmerge_transform(
                           make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
                       make_pass_through_transform(Number<B_K1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
    }();
    return b0_wave_desc;
    // Replacement implementation (post-commit): delegate to the shared helper.
    return MakeWmmaTileDescriptor<LRepeat, LWaves, LPerWmma>(B0BlockDesc_{});
}
template <typename A1BlockDesc_AL0_M_AL1>
@@ -356,6 +428,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
MRepeat,
LRepeat,
KPack,
KInner,
true>())>; // TransposeC (must be true to work), C' = B' x A'
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}

View File

@@ -151,10 +151,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
static constexpr auto AK1Number = Number<AK1Value>{};
static constexpr auto BK1Number = Number<BK1Value>{};
static constexpr index_t KPack = math::max(
math::lcm(AK1Number, BK1Number),
static constexpr index_t KPerWmmaBlk =
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.k_per_wmma);
.k_per_blk;
static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk);
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
static constexpr index_t KPack =
KInner *
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
.k_per_wmma;
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
@@ -218,6 +228,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
KPerBlock,
MPerWmma,
AK1Value,
KPack,
KInner,
KPerWmmaBlk,
UseBlockPaddingA,
PermuteA,
ABlockTransferThreadClusterLengths_AK0_M_AK1,
@@ -251,6 +264,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
KPerBlock,
NPerWmma,
BK1Value,
KPack,
KInner,
KPerWmmaBlk,
UseBlockPaddingB,
PermuteB,
BBlockTransferThreadClusterLengths_BK0_N_BK1,
@@ -563,7 +579,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
NPerWmma,
MRepeat,
NRepeat,
KPack>())>;
KPack,
KInner>())>;
// Used to create obj in global function and pass it to Run method
using EpilogueCShuffle =

View File

@@ -95,6 +95,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
@@ -136,6 +137,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
@@ -173,6 +175,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
@@ -209,6 +212,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 2;
@@ -251,6 +255,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t src_a_data_size = 2;
static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
@@ -301,6 +306,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
// static constexpr index_t acc_data_size = 4;
@@ -339,6 +345,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
@@ -372,6 +379,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
// static constexpr index_t src_a_data_size = 2;
// static constexpr index_t src_b_data_size = 2;
static constexpr index_t acc_data_size = 4;
@@ -413,6 +421,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
@@ -448,6 +457,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
@@ -483,6 +493,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
@@ -518,6 +529,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12,
static constexpr index_t m_per_wmma = 16;
static constexpr index_t n_per_wmma = 16;
static constexpr index_t k_per_wmma = 16;
static constexpr index_t k_per_blk = 8;
static constexpr index_t acc_data_size = 4;
static constexpr index_t acc_pack_number = 1;
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
@@ -768,6 +780,8 @@ struct WmmaGemm
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
__device__ static constexpr index_t GetKPerWaveBlk() { return wmma_instr.k_per_blk; }
template <class FloatA, class FloatB, class FloatC>
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
{