mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 04:49:54 +00:00
Merge commit '1b1c46e508c1fd40a03f54114b6b78629032fb4f' into develop
This commit is contained in:
@@ -28,6 +28,7 @@ template <BlockGemmPipelineVersion BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
constexpr auto BlockGemmPipeline_Selector()
|
||||
{
|
||||
@@ -52,6 +53,7 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>{};
|
||||
}
|
||||
else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3)
|
||||
@@ -75,6 +77,7 @@ constexpr auto BlockGemmPipeline_Selector()
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>{};
|
||||
}
|
||||
else
|
||||
|
||||
@@ -30,6 +30,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_base
|
||||
{
|
||||
@@ -38,6 +39,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
static constexpr auto I6 = Number<6>{};
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -54,15 +56,20 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
static constexpr index_t B_KRow = 1;
|
||||
#endif
|
||||
|
||||
static constexpr index_t A_K1 = AWmmaTileDesc{}.GetLength(I5);
|
||||
static constexpr index_t B_K1 = BWmmaTileDesc{}.GetLength(I5);
|
||||
static constexpr auto wmma_gemm = WmmaGemm<ComputeTypeA,
|
||||
ComputeTypeB,
|
||||
AccDataType,
|
||||
MPerWmma,
|
||||
NPerWmma,
|
||||
KPack / KInner,
|
||||
TransposeC>{};
|
||||
|
||||
static constexpr index_t KPerThread = wmma_gemm.wmma_instr.k_per_blk * KInner;
|
||||
static constexpr index_t A_K1 = ck::math::min(AWmmaTileDesc{}.GetLength(I6), KPerThread);
|
||||
static constexpr index_t B_K1 = ck::math::min(BWmmaTileDesc{}.GetLength(I6), KPerThread);
|
||||
|
||||
static_assert(KPack % (A_K1 * A_KRow) == 0, "wrong!");
|
||||
static_assert(KPack % (B_K1 * B_KRow) == 0, "wrong!");
|
||||
|
||||
static constexpr auto wmma_gemm =
|
||||
WmmaGemm<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma, KPack, TransposeC>{};
|
||||
|
||||
static constexpr index_t KRepeat = KPerBlock / KPack;
|
||||
|
||||
static constexpr auto WmmaK = Number<wmma_gemm.wmma_instr.k_per_wmma>{};
|
||||
@@ -191,8 +198,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
const auto wmma_krow = 0;
|
||||
#endif
|
||||
|
||||
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
|
||||
return make_tuple(0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
|
||||
return make_tuple(0, 0, 0, waveId_m, wmma_krow, wmma_a_idx, 0);
|
||||
}
|
||||
|
||||
__device__ static auto CalculateBThreadOriginDataIndex()
|
||||
@@ -209,8 +215,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
const auto wmma_krow = 0;
|
||||
#endif
|
||||
|
||||
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
|
||||
return make_tuple(0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
|
||||
return make_tuple(0, 0, 0, waveId_n, wmma_krow, wmma_b_idx, 0);
|
||||
}
|
||||
|
||||
template <index_t m0, index_t n0>
|
||||
@@ -241,7 +246,7 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
return make_tuple(c_thread_m, c_thread_n);
|
||||
}
|
||||
|
||||
using Tuple6 = decltype(CalculateAThreadOriginDataIndex());
|
||||
using Tuple7 = decltype(CalculateAThreadOriginDataIndex());
|
||||
|
||||
/**
|
||||
* @brief Constructor for BlockwiseGemmWmmaops_pipeline_base.
|
||||
@@ -261,8 +266,8 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
* repeat dimensions.
|
||||
*/
|
||||
__host__ __device__
|
||||
BlockwiseGemmWmmaops_pipeline_base(Tuple6 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple6 b_origin = CalculateBThreadOriginDataIndex())
|
||||
BlockwiseGemmWmmaops_pipeline_base(Tuple7 a_origin = CalculateAThreadOriginDataIndex(),
|
||||
Tuple7 b_origin = CalculateBThreadOriginDataIndex())
|
||||
: a_thread_copy_(a_origin), b_thread_copy_(b_origin)
|
||||
{
|
||||
static_assert(AWmmaTileDesc::IsKnownAtCompileTime() &&
|
||||
@@ -343,12 +348,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
Number<KRepeat>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<A_K1>{}),
|
||||
make_tuple(Number<A_K1>{},
|
||||
Number<KPack / A_KRow>{},
|
||||
Number<KPack / A_KRow * MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
static constexpr auto b_thread_desc_ =
|
||||
@@ -357,12 +364,14 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
Number<KRepeat>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<B_K1>{}),
|
||||
make_tuple(Number<B_K1>{},
|
||||
Number<KPack / B_KRow>{},
|
||||
Number<KPack / B_KRow * NRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
// C[M, N, NumRegWmma]
|
||||
@@ -374,9 +383,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
@@ -385,9 +394,9 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1
|
||||
{
|
||||
@@ -55,6 +56,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
@@ -114,10 +118,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::WaveSize;
|
||||
using typename Base::HotLoopInstList;
|
||||
|
||||
using Base::A_K1;
|
||||
@@ -187,6 +191,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
@@ -211,27 +217,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
auto blockwise_gemm_func = [&]() {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
if constexpr(m0 == I0)
|
||||
{
|
||||
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(
|
||||
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
else
|
||||
@@ -239,45 +241,60 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(
|
||||
Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, I0, I0, I0, I0),
|
||||
make_tuple(I0, n0, I0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, I0, I0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, I0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -324,8 +341,10 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
});
|
||||
}
|
||||
static_for<0, NRepeat, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
|
||||
static_for<0, KInner, 1>{}([&](auto) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // WMMA
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -348,20 +367,20 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
protected:
|
||||
// A[MRepeat, I1, I1, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, Number<A_K1>{}));
|
||||
make_tuple(Number<KPack / A_K1 / A_KRow>{}, I1, I1, I1, I1, I1, Number<A_K1>{}));
|
||||
|
||||
// B[NRepeat, N1, N2, KPack]
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, Number<B_K1>{}));
|
||||
static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<KPack / B_K1 / B_KRow>{}, Number<NRepeat>{}, I1, I1, I1, I1, Number<B_K1>{}));
|
||||
|
||||
using AThreadCopy =
|
||||
ThreadwiseTensorSliceTransfer_v4<ADataType,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
@@ -370,9 +389,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
@@ -399,6 +418,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
BlockSize,
|
||||
@@ -419,6 +439,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -438,6 +459,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
@@ -458,6 +480,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
@@ -532,6 +555,8 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
@@ -557,33 +582,22 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, KRepeat, KRepeatPerCluster>{}([&](auto k0_offset) {
|
||||
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<(k0_offset + k0_inner) * KPack / A_K1 / A_KRow>{},
|
||||
m0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0_inner, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0_inner, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
if constexpr(ck::is_same<BScaleStruct, Empty>::value == true)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
|
||||
n0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0),
|
||||
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0),
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
@@ -592,18 +606,13 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<(k0_offset + k0_inner) * KPack / B_K1 / B_KRow>{},
|
||||
n0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I0),
|
||||
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(I0)[Number<
|
||||
n0 * BScaleStruct::num_scale_k_block +
|
||||
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0),
|
||||
make_tuple(I0, n0, k0_inner, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
@@ -622,62 +631,69 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
static_for<0, KRepeatPerCluster, 1>{}([&](auto k0_inner) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / A_K1>{},
|
||||
m0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard.
|
||||
// B) reduce VMEM FIFO congestion by applying small delays to
|
||||
// different wavefronts.
|
||||
// It is performed near the end of MAC cluster to minimize lgkmcnt
|
||||
// penalty
|
||||
if constexpr(k0_offset + k0_inner == KRepeat - 1 &&
|
||||
m0 == MRepeat - 1 && n0 == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / B_K1>{},
|
||||
n0,
|
||||
k0_inner,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
// The block_sync_lds() here performs double duty:
|
||||
// A) safeguard against data hazard.
|
||||
// B) reduce VMEM FIFO congestion by applying small delays to
|
||||
// different wavefronts.
|
||||
// It is performed near the end of MAC cluster to minimize lgkmcnt
|
||||
// penalty
|
||||
if constexpr(k0_offset + k0_inner == KRepeat - 1 && m0 == MRepeat - 1 &&
|
||||
n0 == NRepeat - 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
block_sync_lds();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
if constexpr(k0_inner == 0 && m0 == 0 && n0 == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -729,12 +745,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
Number<KRepeatPerCluster>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<A_K1>{}),
|
||||
make_tuple(Number<A_K1>{},
|
||||
Number<KPack / A_KRow>{},
|
||||
Number<KPack / A_KRow * MRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
static constexpr auto b_thread_desc_ =
|
||||
@@ -743,12 +761,14 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
Number<KRepeatPerCluster>{},
|
||||
I1,
|
||||
I1,
|
||||
I1,
|
||||
Number<B_K1>{}),
|
||||
make_tuple(Number<B_K1>{},
|
||||
Number<KPack / B_KRow>{},
|
||||
Number<KPack / B_KRow * NRepeat>{},
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
I1));
|
||||
|
||||
using AThreadCopy =
|
||||
@@ -756,9 +776,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
ComputeTypeA,
|
||||
decltype(a_block_desc_k0_m0_m1_m2_k1),
|
||||
decltype(a_thread_desc_),
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / A_K1 / A_KRow, 1, 1, 1, 1, 1, A_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
A_K1,
|
||||
A_K1>;
|
||||
|
||||
@@ -767,9 +787,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
ComputeTypeB,
|
||||
decltype(b_block_desc_k0_n0_n1_n2_k1),
|
||||
decltype(b_thread_desc_),
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5>,
|
||||
5,
|
||||
Sequence<KPack / B_K1 / B_KRow, 1, 1, 1, 1, 1, B_K1>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6>,
|
||||
6,
|
||||
B_K1,
|
||||
B_K1>;
|
||||
|
||||
|
||||
@@ -32,6 +32,7 @@ template <BlockGemmPipelineScheduler BlkGemmPipelineVer,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC = false>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3
|
||||
{
|
||||
@@ -55,6 +56,7 @@ template <index_t BlockSize,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
bool TransposeC>
|
||||
struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockSize,
|
||||
@@ -75,6 +77,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
: BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
ADataType,
|
||||
@@ -94,6 +97,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>
|
||||
{
|
||||
using Base = BlockwiseGemmWmmaops_pipeline_base<BlockSize,
|
||||
@@ -114,6 +118,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
|
||||
@@ -290,40 +295,37 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
{
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(
|
||||
a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(Number<k0 * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
|
||||
if constexpr(ck::is_same_v<BScaleStruct, Empty>)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(Number<k0 * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
}
|
||||
});
|
||||
@@ -364,6 +366,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
@@ -424,41 +429,48 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<ik / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -489,31 +501,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -531,31 +559,47 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / A_K1>{}, m0, k0, I0, I0, Number<ik % A_K1>{}))>{}];
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k0,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
static_for<0, KPack / B_KRow, 1>{}([&](auto ik) {
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(make_tuple(
|
||||
Number<ik / B_K1>{}, n0, k0, I0, I0, Number<ik % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, I0));
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -17,6 +17,9 @@ template <typename ABLayout,
|
||||
index_t KPerBlock,
|
||||
index_t MNPerWmma,
|
||||
index_t ABK1Value,
|
||||
index_t KPack,
|
||||
index_t KInner,
|
||||
index_t KPerWmmaBlk,
|
||||
bool UseBlockPaddingAB,
|
||||
bool PermuteAB,
|
||||
typename ABBlockTransferThreadClusterLengths_ABK0_MN_ABK1,
|
||||
@@ -374,14 +377,93 @@ struct ABTransferThreadTiles
|
||||
#else
|
||||
constexpr auto KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
if constexpr(KInner > 1)
|
||||
{
|
||||
// KPack = KInner * KPerWmma
|
||||
// K1 = KInner * KPerWmmaBlk
|
||||
// Each thread loads multiple tiles with one instruction
|
||||
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(Number<ABK0 / KRow>{}, KRow, Number<1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// KPack = KPerWmma (KInner == 1)
|
||||
if constexpr(ABK1 <= KPerWmmaBlk)
|
||||
{
|
||||
// K1 <= single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
|
||||
// (rest of the single WMMA tile for single thread) and then over KRow
|
||||
// (rest of the single WMMA tile for single wave)
|
||||
// KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<ABK0 / (KPack / ABK1)>{}, KRow, Number<KPack / KRow / ABK1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<ABK1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// K1 > single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 > single tile, each thread loads KPerWmmaBlk and the next
|
||||
// KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout).
|
||||
// This layout is needed to support for example AK1 > single tile and
|
||||
// BK1 <= single tile in the same gemm
|
||||
// KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma -
|
||||
// K1
|
||||
constexpr auto desc1 = transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<ABK0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_unmerge_transform(make_tuple(Number<ABK1 / KPack>{},
|
||||
Number<KPack / KPerWmmaBlk / KRow>{},
|
||||
Number<KRow>{},
|
||||
Number<KPerWmmaBlk>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc1,
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<KPack / KPerWmmaBlk / KRow>{}),
|
||||
make_pass_through_transform(Number<MNRepeat>{}),
|
||||
make_merge_transform(make_tuple(Number<ABK0>{}, Number<ABK1 / KPack>{})),
|
||||
make_pass_through_transform(Number<MNWaves>{}),
|
||||
make_pass_through_transform(Number<KRow>{}),
|
||||
make_pass_through_transform(Number<MNPerWmma>{}),
|
||||
make_pass_through_transform(Number<KPerWmmaBlk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2, 3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__device__ static constexpr auto GetBlockStep()
|
||||
|
||||
@@ -313,14 +313,16 @@ struct ABTransferWaveTiles
|
||||
// This is a block descriptor used to read LDS memory into register
|
||||
// It's defined in a way consistent with the existing implementation to
|
||||
// avoid changes in the pipelines
|
||||
return make_naive_tensor_descriptor(make_tuple(Number<KPerBlock / KPack>{},
|
||||
return make_naive_tensor_descriptor(make_tuple(I1,
|
||||
Number<MNRepeat>{},
|
||||
Number<KPerBlock / KPack>{},
|
||||
Number<MNWaves>{},
|
||||
Number<MNKRow>{},
|
||||
Number<MNPerWmma>{},
|
||||
Number<ABK1Value>{}),
|
||||
make_tuple(Number<KPack * MNPerWmma>{},
|
||||
make_tuple(I0,
|
||||
Number<KPerBlock * MNPerWmma * MNWaves>{},
|
||||
Number<KPack * MNPerWmma>{},
|
||||
Number<KPerBlock * MNPerWmma>{},
|
||||
Number<MNPerWmma * ABK1Value>{},
|
||||
Number<ABK1Value>{},
|
||||
|
||||
@@ -109,9 +109,20 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
static constexpr auto LWaves = LPerBlock / (LRepeat * LPerWmma);
|
||||
static constexpr auto NWaves = NPerBlock / (NRepeat * NPerWmma);
|
||||
|
||||
// TODO: I am pretty sure this is always 16 and *should* always be 16.
|
||||
static constexpr auto KPack =
|
||||
math::integer_least_multiple(math::integer_least_multiple(AK1Value, BK1Value), 16);
|
||||
static constexpr index_t KPerWmmaBlk =
|
||||
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
|
||||
.k_per_blk;
|
||||
|
||||
static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
|
||||
|
||||
static constexpr index_t KPack =
|
||||
KInner *
|
||||
WmmaSelector<ADataType, B0DataType, Acc0DataType, MPerWmma, LPerWmma>::selected_wmma
|
||||
.k_per_wmma;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -201,54 +212,115 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
return b1_block_copy_step;
|
||||
}
|
||||
|
||||
template <index_t MNRepeat, index_t MNWaves, index_t MNPerWmma, typename BlockDesc>
|
||||
__host__ __device__ static constexpr auto MakeWmmaTileDescriptor(const BlockDesc&)
|
||||
{
|
||||
// K0_MN_K1 -> K0_MNRepeat_MNWaves_KRow_MNPerWmma_K1
|
||||
constexpr auto K0 = BlockDesc{}.GetLength(I0);
|
||||
constexpr auto K1 = BlockDesc{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto KRow = I2;
|
||||
#else
|
||||
constexpr auto KRow = I1;
|
||||
#endif
|
||||
|
||||
if constexpr(KInner > 1)
|
||||
{
|
||||
// KPack = KInner * KPerWmma
|
||||
// K1 = KInner * KPerWmmaBlk
|
||||
// Each thread loads multiple tiles with one instruction
|
||||
// 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_unmerge_transform(make_tuple(Number<K0 / (KRow)>{}, KRow, Number<1>{})),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// KPack = KPerWmma (KInner == 1)
|
||||
if constexpr(K1 <= KPerWmmaBlk)
|
||||
{
|
||||
// K1 <= single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 <= single tile, K0 is unmerged first over KPack / KRow / K1
|
||||
// (rest of the single WMMA tile for single thread) and then over KRow
|
||||
// (rest of the single WMMA tile for single wave)
|
||||
// KPack / KRow / K1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1
|
||||
return transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(
|
||||
Number<K0 / (KPack / K1)>{}, KRow, Number<KPack / KRow / K1>{})),
|
||||
make_unmerge_transform(make_tuple(
|
||||
Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_pass_through_transform(Number<K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
// K1 > single tile (KPerWmmaBlk)
|
||||
// Each thread will load KPerWmmaBlk for the WMMA instruction
|
||||
// Since K1 > single tile, each thread loads KPerWmmaBlk and the next
|
||||
// KPerWmmaBlk chunk is loaded by a different thread in the same wave (WMMA layout).
|
||||
// This layout is needed to support for example AK1 > single tile and
|
||||
// BK1 <= single tile in the same gemm
|
||||
// KPack / KPerWmmaBlk / KRow - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma -
|
||||
// K1
|
||||
constexpr auto desc1 = transform_tensor_descriptor(
|
||||
BlockDesc{},
|
||||
make_tuple(
|
||||
make_pass_through_transform(Number<K0>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MNRepeat>{}, Number<MNWaves>{}, Number<MNPerWmma>{})),
|
||||
make_unmerge_transform(make_tuple(Number<K1 / KPack>{},
|
||||
Number<KPack / KPerWmmaBlk / KRow>{},
|
||||
Number<KRow>{},
|
||||
Number<KPerWmmaBlk>{}))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<2>{}, Sequence<1, 4, 6>{}, Sequence<3, 0, 5, 7>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
desc1,
|
||||
make_tuple(make_pass_through_transform(Number<KPack / KPerWmmaBlk / KRow>{}),
|
||||
make_pass_through_transform(Number<MNRepeat>{}),
|
||||
make_merge_transform(make_tuple(Number<K0>{}, Number<K1 / KPack>{})),
|
||||
make_pass_through_transform(Number<MNWaves>{}),
|
||||
make_pass_through_transform(Number<KRow>{}),
|
||||
make_pass_through_transform(Number<MNPerWmma>{}),
|
||||
make_pass_through_transform(Number<KPerWmmaBlk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2, 3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{},
|
||||
Sequence<7>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{},
|
||||
Sequence<6>{}));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ABlockDesc_>
|
||||
__host__ __device__ static constexpr auto MakeAWaveDescriptor(const ABlockDesc_&)
|
||||
{
|
||||
constexpr auto a_wave_desc = [&]() {
|
||||
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
|
||||
constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
|
||||
constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto A_KRow = I2;
|
||||
#else
|
||||
constexpr auto A_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
ABlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<A_K0 / A_KRow>{}, A_KRow)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWmma>{})),
|
||||
make_pass_through_transform(Number<A_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
}();
|
||||
|
||||
return a_wave_desc;
|
||||
return MakeWmmaTileDescriptor<MRepeat, MWaves, MPerWmma>(ABlockDesc_{});
|
||||
}
|
||||
|
||||
template <typename B0BlockDesc_>
|
||||
__host__ __device__ static constexpr auto MakeB0WaveDescriptor(const B0BlockDesc_&)
|
||||
{
|
||||
constexpr auto b0_wave_desc = [&]() {
|
||||
// BK0_L_BK1 -> BK0_LRepeat_Lwaves_BKRow_LPerWmma_BK1
|
||||
constexpr auto B_K0 = B0BlockDesc_{}.GetLength(I0);
|
||||
constexpr auto B_K1 = B0BlockDesc_{}.GetLength(I2);
|
||||
#ifdef __gfx12__
|
||||
constexpr auto B_KRow = I2;
|
||||
#else
|
||||
constexpr auto B_KRow = I1;
|
||||
#endif
|
||||
return transform_tensor_descriptor(
|
||||
B0BlockDesc_{},
|
||||
make_tuple(make_unmerge_transform(make_tuple(Number<B_K0 / B_KRow>{}, B_KRow)),
|
||||
make_unmerge_transform(
|
||||
make_tuple(Number<LRepeat>{}, Number<LWaves>{}, Number<LPerWmma>{})),
|
||||
make_pass_through_transform(Number<B_K1>{})),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0, 3>{}, Sequence<1, 2, 4>{}, Sequence<5>{}));
|
||||
}();
|
||||
|
||||
return b0_wave_desc;
|
||||
return MakeWmmaTileDescriptor<LRepeat, LWaves, LPerWmma>(B0BlockDesc_{});
|
||||
}
|
||||
|
||||
template <typename A1BlockDesc_AL0_M_AL1>
|
||||
@@ -356,6 +428,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3
|
||||
MRepeat,
|
||||
LRepeat,
|
||||
KPack,
|
||||
KInner,
|
||||
true>())>; // TransposeC (must be true to work), C' = B' x A'
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping is controlled by {M01, N01}
|
||||
|
||||
@@ -151,10 +151,20 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
static constexpr auto AK1Number = Number<AK1Value>{};
|
||||
static constexpr auto BK1Number = Number<BK1Value>{};
|
||||
|
||||
static constexpr index_t KPack = math::max(
|
||||
math::lcm(AK1Number, BK1Number),
|
||||
static constexpr index_t KPerWmmaBlk =
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
|
||||
.k_per_wmma);
|
||||
.k_per_blk;
|
||||
|
||||
static constexpr index_t KInnerA = ck::math::integer_divide_ceil(AK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInnerB = ck::math::integer_divide_ceil(BK1Value, KPerWmmaBlk);
|
||||
|
||||
static constexpr index_t KInner = ck::math::min(KInnerA, KInnerB);
|
||||
|
||||
static constexpr index_t KPack =
|
||||
KInner *
|
||||
WmmaSelector<ComputeTypeA, ComputeTypeB, AccDataType, MPerWmma, NPerWmma>::selected_wmma
|
||||
.k_per_wmma;
|
||||
|
||||
using ThisThreadBlock = ThisThreadBlock<BlockSize>;
|
||||
|
||||
@@ -218,6 +228,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
KPerBlock,
|
||||
MPerWmma,
|
||||
AK1Value,
|
||||
KPack,
|
||||
KInner,
|
||||
KPerWmmaBlk,
|
||||
UseBlockPaddingA,
|
||||
PermuteA,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
@@ -251,6 +264,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
KPerBlock,
|
||||
NPerWmma,
|
||||
BK1Value,
|
||||
KPack,
|
||||
KInner,
|
||||
KPerWmmaBlk,
|
||||
UseBlockPaddingB,
|
||||
PermuteB,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
@@ -563,7 +579,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_base
|
||||
NPerWmma,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
KPack>())>;
|
||||
KPack,
|
||||
KInner>())>;
|
||||
|
||||
// Used to create obj in global function and pass it to Run method
|
||||
using EpilogueCShuffle =
|
||||
|
||||
@@ -95,6 +95,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -136,6 +137,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -173,6 +175,7 @@ struct wmma_type<WmmaInstr::wmma_f16_16x16x16_f16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
@@ -209,6 +212,7 @@ struct wmma_type<WmmaInstr::wmma_bf16_16x16x16_bf16,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 2;
|
||||
@@ -251,6 +255,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t src_a_data_size = 2;
|
||||
static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -301,6 +306,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
// static constexpr index_t acc_data_size = 4;
|
||||
@@ -339,6 +345,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf16_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -372,6 +379,7 @@ struct wmma_type<WmmaInstr::wmma_i32_16x16x16_iu8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
// static constexpr index_t src_a_data_size = 2;
|
||||
// static constexpr index_t src_b_data_size = 2;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
@@ -413,6 +421,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8f8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -448,6 +457,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f8bf8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -483,6 +493,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8f8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -518,6 +529,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_bf8bf8_gfx12,
|
||||
static constexpr index_t m_per_wmma = 16;
|
||||
static constexpr index_t n_per_wmma = 16;
|
||||
static constexpr index_t k_per_wmma = 16;
|
||||
static constexpr index_t k_per_blk = 8;
|
||||
static constexpr index_t acc_data_size = 4;
|
||||
static constexpr index_t acc_pack_number = 1;
|
||||
static constexpr index_t num_thread_per_subgroups = n_per_wmma;
|
||||
@@ -768,6 +780,8 @@ struct WmmaGemm
|
||||
|
||||
__device__ static constexpr index_t GetWaveSize() { return wmma_instr.wave_size; }
|
||||
|
||||
__device__ static constexpr index_t GetKPerWaveBlk() { return wmma_instr.k_per_blk; }
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
|
||||
@@ -24,16 +24,43 @@ template <typename GemmConfig, typename T>
|
||||
auto shuffle_b(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 4, 1, 3, 5});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 2, 3, 1, 4});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig, typename T>
|
||||
@@ -55,21 +82,46 @@ template <typename GemmConfig, typename T>
|
||||
auto shuffle_b_permuteN(const ck_tile::HostTensor<T>& t)
|
||||
{
|
||||
assert(t.get_lengths().size() == 2);
|
||||
|
||||
int n_ = t.get_lengths()[1];
|
||||
int k_ = t.get_lengths()[0];
|
||||
constexpr int divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
constexpr int NRepeat = GemmConfig::N_Tile / GemmConfig::N_Warp_Tile / GemmConfig::N_Warp;
|
||||
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
if(ck_tile::is_gfx12_supported())
|
||||
{
|
||||
constexpr int divisor = 2;
|
||||
constexpr int kABK1PerLane = 8;
|
||||
constexpr int kABK0PerLane = GemmConfig::K_Warp_Tile / divisor / kABK1PerLane;
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
kABK0PerLane,
|
||||
divisor,
|
||||
kABK1PerLane});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 6, 5, 2, 7});
|
||||
}
|
||||
else
|
||||
{
|
||||
int divisor = 1;
|
||||
if(ck_tile::is_gfx11_supported())
|
||||
{
|
||||
divisor = 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(is_wave32() == false);
|
||||
divisor = GemmConfig::N_Warp_Tile == 32 ? 2 : 4;
|
||||
}
|
||||
ck_tile::HostTensor<T> t_view({n_ / GemmConfig::N_Tile,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
NRepeat,
|
||||
k_ / GemmConfig::K_Warp_Tile,
|
||||
divisor,
|
||||
GemmConfig::K_Warp_Tile / divisor});
|
||||
std::copy(t.begin(), t.end(), t_view.begin());
|
||||
return ck_tile::reference_permute(t_view, {0, 3, 1, 4, 5, 2, 6});
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -79,6 +79,7 @@ struct WarpGemmAttributeWmma
|
||||
static constexpr index_t kM = Impl::kM;
|
||||
static constexpr index_t kN = Impl::kN;
|
||||
static constexpr index_t kK = Impl::kK;
|
||||
static constexpr index_t kCMLane = Impl::kCMLane;
|
||||
static constexpr index_t kKPerThread = Impl::kABK0PerLane * Impl::kABK1PerLane;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_ar_flatbr_bquant_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_aquant_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_universal_gemm_as_bs_bquant_cr.hpp"
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Common utilities for quantized GEMM block operations
|
||||
template <typename CDataType,
|
||||
typename WarpGemmType,
|
||||
index_t MIterPerWarp,
|
||||
index_t MWarp,
|
||||
index_t NIterPerWarp,
|
||||
index_t NWarp>
|
||||
struct BlockGemmQuantCommon
|
||||
{
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemmType::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -81,11 +82,11 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f = element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f = element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
@@ -100,21 +101,8 @@ struct BlockGemmWeightPreshuffleBQuantARegBRegCReg
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
|
||||
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
return BlockGemmQuantCommon<CDataType, WG, MIterPerWarp, MWarp, NIterPerWarp, NWarp>::
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -24,13 +25,11 @@ struct BlockGemmAQuantBase
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<AQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<AQDataType, float>)
|
||||
{
|
||||
@@ -348,7 +347,7 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
// Thread 0 can read AQ_tile[0, 0] from itself, AQ_tile[1,
|
||||
// 0] from thread 1, ..., and AQ_tile[3, 0] from thread 3.
|
||||
|
||||
constexpr uint32_t kTileRowsOfCPerThread = 4;
|
||||
constexpr uint32_t kTileRowsOfCPerThread = (get_warp_size() == 64) ? 4 : 8;
|
||||
decltype(threadIdx.x) pull_from_lane = 0;
|
||||
if constexpr(WarpGemm::kM == 16)
|
||||
{
|
||||
@@ -409,7 +408,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
// desired row coefficient
|
||||
auto& scale_reg = aq_block_tensor.get_thread_buffer()[src_reg_offset];
|
||||
|
||||
constexpr uint32_t kTileRows = 4;
|
||||
constexpr uint32_t kTileRows = (get_warp_size() == 64) ? 4 : 8;
|
||||
;
|
||||
constexpr uint32_t kTiledCMsPerWarp = WarpGemm::kCMLane * kTileRows;
|
||||
constexpr uint32_t reg_offset_for_row_data = c_row * WarpGemm::kCMLane;
|
||||
// Multiply by 4 because output is stored in tiles of 4
|
||||
@@ -543,20 +543,8 @@ struct AQuantBlockUniversalGemmAsBsCr : public BlockGemmAQuantBase<Problem_>
|
||||
public:
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
return c_block_tensor;
|
||||
return BlockGemmQuantCommon<CDataType, WarpGemm, MIterPerWarp, MWarp, NIterPerWarp, NWarp>::
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
|
||||
@@ -9,6 +9,7 @@
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/ops/gemm_quant/block/block_gemm_quant_common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -24,13 +25,11 @@ struct BlockGemmBQuantBase
|
||||
float scale_reg_f = 0.f;
|
||||
if constexpr(std::is_same_v<BQDataType, ck_tile::fp8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_fp8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_fp8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, ck_tile::bf8_t>)
|
||||
{
|
||||
scale_reg_f =
|
||||
ck_tile::element_wise::amd_assembly_bf8_to_fp32(static_cast<uint32_t>(scale));
|
||||
scale_reg_f = __builtin_amdgcn_cvt_f32_bf8(static_cast<uint32_t>(scale), 0);
|
||||
}
|
||||
else if constexpr(std::is_same_v<BQDataType, float>)
|
||||
{
|
||||
@@ -376,20 +375,8 @@ struct BQuantBlockUniversalGemmAsBsCr : public BlockGemmBQuantBase<Problem_>
|
||||
public:
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
|
||||
return c_block_tensor;
|
||||
return BlockGemmQuantCommon<CDataType, WarpGemm, MIterPerWarp, MWarp, NIterPerWarp, NWarp>::
|
||||
MakeCBlockTile();
|
||||
}
|
||||
|
||||
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
|
||||
|
||||
@@ -240,7 +240,10 @@ struct QuantGemmKernel
|
||||
return dim3(TilePartitioner::GridSize(M, N), 1, KBatch);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
|
||||
CK_TILE_HOST static auto BlockSize()
|
||||
{
|
||||
return is_wave32() ? dim3(kBlockSize / 2) : dim3(kBlockSize);
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr QuantGemmKernelArgs
|
||||
MakeKernelArgs(const QuantGemmHostArgs& hostArgs)
|
||||
|
||||
@@ -41,7 +41,8 @@ template <bool kPadM_,
|
||||
typename BQLayout_ = BLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool DoubleSmemBuffer_ = false,
|
||||
bool UsePersistentKernel_ = false>
|
||||
bool UsePersistentKernel_ = false,
|
||||
int VectorSize_ = 16>
|
||||
struct TileGemmQuantTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
@@ -50,7 +51,7 @@ struct TileGemmQuantTraits
|
||||
|
||||
static constexpr QuantType kQuantType = QuantType_;
|
||||
|
||||
static constexpr int _VectorSize = 16;
|
||||
static constexpr int _VectorSize = VectorSize_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
|
||||
Reference in New Issue
Block a user