mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Extend support for ak1 / bk1 WMMA (#3073)
* Extend AK1 / BK1 support: - Add support for AK1 != BK1 - Add support for AK1, BK1 > 8 - Introduce KInner template parameter for pipelines when loading multiple tiles with one instruction * fix clang format
This commit is contained in:
@@ -28,6 +28,7 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
@@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
@@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>{};
|
||||
}
|
||||
else
|
||||
|
||||
@@ -30,6 +30,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_base
|
||||
{
|
||||
@@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
static constexpr index_t B_KRow = 1;
|
||||
#endif
|
||||
|
||||
static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
|
||||
static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
|
||||
static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
AccDataType,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
KPack / KInner,
|
||||
TransposeC>{};
|
||||
|
||||
static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
|
||||
static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
|
||||
static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);
|
||||
|
||||
static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
|
||||
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
|
||||
|
||||
static constexpr auto wmma_gemm =
|
||||
WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
|
||||
|
||||
static constexpr index_t KRepeat = KPerBlock / KPack;
|
||||
|
||||
static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
|
||||
@@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
const auto wmma_krow = 0;
|
||||
#endif
|
||||
|
||||
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
|
||||
return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
|
||||
return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
@@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
const auto wmma_krow = 0;
|
||||
#endif
|
||||
|
||||
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
|
||||
return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
|
||||
return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
@@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
|
||||
using Tuple7 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
|
||||
@@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
* repeat dimensions.
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
|
||||
BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple7 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
|
||||
@@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
Number<KRepeat>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<A_K1>{}),
|
||||
make_tuple(Number<A_K1>{},
|
||||
Number<KPack / A_KRow>{},
|
||||
Number<KPack / A_KRow * MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
static constexpr auto b_thread_desc_ =
|
||||
@@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
Number<KRepeat>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<B_K1>{}),
|
||||
make_tuple(Number<B_K1>{},
|
||||
Number<KPack / B_KRow>{},
|
||||
Number<KPack / B_KRow * NRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
// C[M, N, NumRegWmma]
|
||||
@@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
@@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1
|
||||
{
|
||||
@@ -55,6 +56,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
@@ -114,10 +118,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::WaveSize;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::A_K1;
|
||||
@@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
@@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
auto blockwise_gemm_func = [&]() {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
if constexpr(m0 == I0)
|
||||
{
|
||||
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(
|
||||
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(
|
||||
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
make_tuple(I0, n0, I0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, I0, I0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, I0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
}
|
||||
static_for<0, NRepeat, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
|
||||
static_for<0, KInner, 1>{}([&](auto) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
protected:
|
||||
// A[MRepeat, I1, I1, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, Number<A_K1>{}));
|
||||
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
|
||||
|
||||
// B[NRepeat, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}));
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
|
||||
|
||||
using AThreadCopy =
|
||||
ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
@@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
@@ -399,6 +418,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
BlockSize,
|
||||
@@ -419,6 +439,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -438,6 +459,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
@@ -458,6 +480,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
@@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
@@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
|
||||
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
|
||||
m0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0_inner, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
|
||||
n0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0),
|
||||
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0),
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
@@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
|
||||
n0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0),
|
||||
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(I0)[Number<
|
||||
n0 * BScaleStruct::num_scale_k_block +
|
||||
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0),
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
@@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / A_K1>{},
|
||||
m0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard.
|
||||
// B) reduce VMEM FIFO congestion by applying small delays to
|
||||
// different wavefronts.
|
||||
// It is performed near the end of MAC cluster to minimize lgkmcnt
|
||||
// penalty
|
||||
if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
|
||||
m0 == MRepeat - 1 && n0 == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / B_K1>{},
|
||||
n0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard.
|
||||
// B) reduce VMEM FIFO congestion by applying small delays to
|
||||
// different wavefronts.
|
||||
// It is performed near the end of MAC cluster to minimize lgkmcnt
|
||||
// penalty
|
||||
if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
|
||||
n0 == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
Number<KRepeatPerCluster>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<A_K1>{}),
|
||||
make_tuple(Number<A_K1>{},
|
||||
Number<KPack / A_KRow>{},
|
||||
Number<KPack / A_KRow * MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
static constexpr auto b_thread_desc_ =
|
||||
@@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
Number<KRepeatPerCluster>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<B_K1>{}),
|
||||
make_tuple(Number<B_K1>{},
|
||||
Number<KPack / B_KRow>{},
|
||||
Number<KPack / B_KRow * NRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
using AThreadCopy =
|
||||
@@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
@@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3
|
||||
{
|
||||
@@ -55,6 +56,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
@@ -114,6 +118,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
|
||||
@@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
|
||||
if constexpr(ck::is_same_v<BScaleStruct, Empty>)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
@@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,6 +17,9 @@ template <typename ABLayout,
|
||||
index_t KPerBlock,
|
||||
index_t MNPerWmma,
|
||||
index_t ABK1Value,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
index_t KPerWmmaBlk,
|
||||
bool UseBlockPaddingAB,
|
||||
bool PermuteAB,
|
||||
typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
|
||||
@@ -374,14 +377,93 @@ struct ABTransferThreadTiles
|
||||
#else
|
||||
constexpr auto KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
if constexpr(KInner > 1)
|
||||
{
|
||||
// KPack = KInner * KPerWmma
|
||||
// K1 = KInner * KPerWmmaBlk
|
||||
// Each thread loads multiple tiles with one instruction
|
||||
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// KPack = KPerWmma (KInner == 1)
|
||||
if constexpr(ABK1 <= KPerWmmaBlk)
|
||||
{
|
||||
// K1 <= single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
|
||||
// (rest of the single WMMA tile for single thread) and then over KRow
|
||||
// (rest of the single WMMA tile for single wave)
|
||||
// KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<ABK0 / (KPack / ABK1)>{}, KRow, Number<KPack / KRow / ABK1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// K1 > single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 > single tile, each thread loads KPerWmmaBlk and the next
|
||||
// KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout).
|
||||
// This layout is needed to support for example AK1 > single tile and
|
||||
// BK1 <= single tile in the same gemm
|
||||
// KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma -
|
||||
// K1
|
||||
constexpr auto desc1 = transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<ABK0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_unmerge_transform(make_tuple(Number<ABK1 / KPack>{},
|
||||
Number<KPack / KPerWmmaBlk / KRow>{},
|
||||
Number<KRow>{},
|
||||
Number<KPerWmmaBlk>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc1,
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KPack / KPerWmmaBlk / KRow>{}),
|
||||
make_pass_through_transform(Number<MNRepeat>{}),
|
||||
make_merge_transform(make_tuple(Number<ABK0>{}, Number<ABK1 / KPack>{})),
|
||||
make_pass_through_transform(Number<MNWaves>{}),
|
||||
make_pass_through_transform(Number<KRow>{}),
|
||||
make_pass_through_transform(Number<MNPerWmma>{}),
|
||||
make_pass_through_transform(Number<KPerWmmaBlk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2, 3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetBlockStep()
|
||||
|
||||
@@ -313,14 +313,16 @@ struct ABTransferWaveTiles
|
||||
// This is a block descriptor used to read LDS memory into register
|
||||
// It's defined in a way consistent with the existing implementation to
|
||||
// avoid changes in the pipelines
|
||||
return make_naive_tensor_descriptor(make_tuple(Number<KPerBlock / KPack>{},
|
||||
return make_naive_tensor_descriptor(make_tuple(I1,
|
||||
Number<MNRepeat>{},
|
||||
Number<KPerBlock / KPack>{},
|
||||
Number<MNWaves>{},
|
||||
Number<MNKRow>{},
|
||||
Number<MNPerWmma>{},
|
||||
Number<ABK1Value>{}),
|
||||
make_tuple(Number<KPack * MNPerWmma>{},
|
||||
make_tuple(I0,
|
||||
Number<KPerBlock * MNPerWmma * MNWaves>{},
|
||||
Number<KPack * MNPerWmma>{},
|
||||
Number<KPerBlock * MNPerWmma>{},
|
||||
Number<MNPerWmma * ABK1Value>{},
|
||||
Number<ABK1Value>{},
|
||||
|
||||
@@ -109,9 +109,20 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
// TODO: I am pretty sure this is always 16 and *should* always be 16.
|
||||
static constexpr auto KPack =
|
||||
math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16);
|
||||
static constexpr index_t KPerWmmaBlk =
|
||||
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
|
||||
.k_per_blk;
|
||||
|
||||
static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
|
||||
|
||||
static constexpr index_t KPack =
|
||||
KInner *
|
||||
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
|
||||
.k_per_wmma;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -201,54 +212,115 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
return b1_block_copy_step;
|
||||
}
|
||||
|
||||
template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
|
||||
__host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
|
||||
{
|
||||
// K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1
|
||||
constexpr auto K0 = BlockDesc{}.GetLength(I0);
|
||||
constexpr auto K1 = BlockDesc{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto KRow = I2;
|
||||
#else
|
||||
constexpr auto KRow = I1;
|
||||
#endif
|
||||
|
||||
if constexpr(KInner > 1)
|
||||
{
|
||||
// KPack = KInner * KPerWmma
|
||||
// K1 = KInner * KPerWmmaBlk
|
||||
// Each thread loads multiple tiles with one instruction
|
||||
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(Number<K0 / (KRow)>{}, KRow, Number<1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// KPack = KPerWmma (KInner == 1)
|
||||
if constexpr(K1 <= KPerWmmaBlk)
|
||||
{
|
||||
// K1 <= single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
|
||||
// (rest of the single WMMA tile for single thread) and then over KRow
|
||||
// (rest of the single WMMA tile for single wave)
|
||||
// KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
Number<K0 / (KPack / K1)>{}, KRow, Number<KPack / KRow / K1>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// K1 > single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 > single tile, each thread loads KPerWmmaBlk and the next
|
||||
// KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout).
|
||||
// This layout is needed to support for example AK1 > single tile and
|
||||
// BK1 <= single tile in the same gemm
|
||||
// KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma -
|
||||
// K1
|
||||
constexpr auto desc1 = transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_unmerge_transform(make_tuple(Number<K1 / KPack>{},
|
||||
Number<KPack / KPerWmmaBlk / KRow>{},
|
||||
Number<KRow>{},
|
||||
Number<KPerWmmaBlk>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc1,
|
||||
make_tuple(make_pass_through_transform(Number<KPack / KPerWmmaBlk / KRow>{}),
|
||||
make_pass_through_transform(Number<MNRepeat>{}),
|
||||
make_merge_transform(make_tuple(Number<K0>{}, Number<K1 / KPack>{})),
|
||||
make_pass_through_transform(Number<MNWaves>{}),
|
||||
make_pass_through_transform(Number<KRow>{}),
|
||||
make_pass_through_transform(Number<MNPerWmma>{}),
|
||||
make_pass_through_transform(Number<KPerWmmaBlk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2, 3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_>
|
||||
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
|
||||
{
|
||||
constexpr auto a_wave_desc = [&]() {
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
}();
|
||||
|
||||
return a_wave_desc;
|
||||
return MakeWmmaTileDescriptor<MRepeat, MWaves, MPerWmma>(ABlockDesc_{});
|
||||
}
|
||||
|
||||
template <typename B0BlockDesc_>
|
||||
__host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
|
||||
{
|
||||
constexpr auto b0_wave_desc = [&]() {
|
||||
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1
|
||||
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
B0BlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
}();
|
||||
|
||||
return b0_wave_desc;
|
||||
return MakeWmmaTileDescriptor<LRepeat, LWaves, LPerWmma>(B0BlockDesc_{});
|
||||
}
|
||||
|
||||
template <typename A1BlockDesc_AL0_M_AL1>
|
||||
@@ -356,6 +428,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
MRepeat,
|
||||
LRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
true>())>; // TransposeC (must be true to work), C' = B' x A'
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
|
||||
|
||||
@@ -151,10 +151,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number),
|
||||
static constexpr index_t KPerWmmaBlk =
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
|
||||
.k_per_wmma);
|
||||
.k_per_blk;
|
||||
|
||||
static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
|
||||
|
||||
static constexpr index_t KPack =
|
||||
KInner *
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
|
||||
.k_per_wmma;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -218,6 +228,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
AK1Value,
|
||||
KPack,
|
||||
KInner,
|
||||
KPerWmmaBlk,
|
||||
UseBlockPaddingA,
|
||||
PermuteA,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
@@ -251,6 +264,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
KPerBlock,
|
||||
NPerWmma,
|
||||
BK1Value,
|
||||
KPack,
|
||||
KInner,
|
||||
KPerWmmaBlk,
|
||||
UseBlockPaddingB,
|
||||
PermuteB,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
@@ -563,7 +579,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>())>;
|
||||
KPack,
|
||||
KInner>())>;
|
||||
|
||||
// Used to create obj in global function and pass it to Run method
|
||||
using EpilogueCShuffle =
|
||||
|
||||
@@ -95,6 +95,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -136,6 +137,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -173,6 +175,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
@@ -209,6 +212,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
@@ -251,6 +255,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -301,6 +306,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
// static constexpr index_t acc_data_size = 4;
|
||||
@@ -339,6 +345,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -372,6 +379,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -413,6 +421,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -448,6 +457,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -483,6 +493,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -518,6 +529,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -768,6 +780,8 @@ struct WmmaGemm
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
|
||||
|
||||
__device__ static constexpr index_t GetKPerWaveBlk() { return wmma_instr.k_per_blk; }
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user