mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Wmma support for gemm_ab_scale (#3314)
* Support gemm_ab_scale: - Add tests - Integrate scaling implementation in multiple D - Generalize existing b_scale for ab_scale - Add instances - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK - Add support for all layouts supported by xdl - Fix splitk xdl * Fix copyright * Wmma support for gemm_blockscale_wp (#3315) * Support for preshuffle with ab scale - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale - add support for AScaleLayout and BScaleLayout (can be different from ALayout and BLayout, respectively) - add Run method in v1 pipeline to support preshuffle + scaling - add support for preshuffle gemms in common invoker - Add splitk support * Fix copyright header
This commit is contained in:
@@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t ScaleSliceSizeN,
|
||||
template <index_t ScaleSliceSizeMN,
|
||||
index_t ScaleSliceStrideMN,
|
||||
index_t ScaleSliceSizeK,
|
||||
index_t NWaves,
|
||||
index_t ScaleBlockK,
|
||||
index_t NumberOfBuffers,
|
||||
index_t RegSizePerWmma,
|
||||
typename GridDesc,
|
||||
typename ThreadCopy,
|
||||
typename GridBuffer,
|
||||
typename ThreadStaticBuffer,
|
||||
typename BScaleThreadDesc>
|
||||
struct BScale
|
||||
typename ThreadDesc>
|
||||
struct ABScale
|
||||
{
|
||||
__device__ BScale(GridDesc b_scale_grid_desc_,
|
||||
ThreadCopy b_scale_thread_copy_,
|
||||
GridBuffer b_scale_grid_buf_)
|
||||
: b_scale_thread_copy(b_scale_thread_copy_),
|
||||
b_scale_grid_desc(b_scale_grid_desc_),
|
||||
b_scale_grid_buf(b_scale_grid_buf_) {};
|
||||
__device__ ABScale(GridDesc scale_grid_desc_,
|
||||
ThreadCopy scale_thread_copy_,
|
||||
GridBuffer scale_grid_buf_)
|
||||
: scale_thread_copy(scale_thread_copy_),
|
||||
scale_grid_desc(scale_grid_desc_),
|
||||
scale_grid_buf(scale_grid_buf_) {};
|
||||
|
||||
static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{});
|
||||
static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block;
|
||||
|
||||
static constexpr auto b_scale_thread_desc = BScaleThreadDesc{};
|
||||
static constexpr index_t num_slice_mn = ScaleSliceSizeMN;
|
||||
static constexpr index_t num_slice_k = ScaleSliceSizeK;
|
||||
static constexpr index_t reg_size_per_wmma = RegSizePerWmma;
|
||||
|
||||
static constexpr auto b_scale_thread_copy_step =
|
||||
make_tuple(make_multi_index(NWaves * NPerWmma, 0),
|
||||
make_multi_index(-NPerBlock, 0),
|
||||
make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK));
|
||||
static constexpr auto scale_thread_desc = ThreadDesc{};
|
||||
|
||||
static constexpr auto scale_thread_copy_step =
|
||||
make_tuple(make_multi_index(ScaleSliceStrideMN, 0),
|
||||
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0),
|
||||
make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN,
|
||||
ScaleSliceSizeK));
|
||||
|
||||
template <index_t NBuffer>
|
||||
__device__ void GlobalLoad(bool cond)
|
||||
{
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(n0, Number<0>{}),
|
||||
b_scale_thread_bufs(Number<NBuffer>{}));
|
||||
static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) {
|
||||
scale_thread_copy.Run(scale_grid_desc,
|
||||
scale_grid_buf,
|
||||
scale_thread_desc,
|
||||
make_tuple(m0 * Number<RegSizePerWmma>{}, Number<0>{}),
|
||||
scale_thread_bufs(Number<NBuffer>{}));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step.At(Number<0>{}));
|
||||
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
|
||||
scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if(cond)
|
||||
{
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step.At(Number<2>{}));
|
||||
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
|
||||
scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc,
|
||||
b_scale_thread_copy_step.At(Number<1>{}));
|
||||
scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc,
|
||||
scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
ThreadCopy b_scale_thread_copy;
|
||||
GridDesc b_scale_grid_desc;
|
||||
GridBuffer b_scale_grid_buf;
|
||||
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> b_scale_thread_bufs;
|
||||
ThreadCopy scale_thread_copy;
|
||||
GridDesc scale_grid_desc;
|
||||
GridBuffer scale_grid_buf;
|
||||
StaticallyIndexedArray<ThreadStaticBuffer, Number<NumberOfBuffers>{}> scale_thread_bufs;
|
||||
};
|
||||
|
||||
// CScale: per-thread helper that (1) stores the element-wise product of the A and B
// scale factors and (2) owns a temporary accumulator (c_thread_buf_per_scale) whose
// contents are multiplied by that product and added into the real C thread buffer.
// AScaleStruct / BScaleStruct must expose num_slice_k, num_slice_mn, reg_size_per_wmma,
// scale_thread_desc and scale_thread_bufs (all are used below).
// NOTE(review): wmma_gemm, c_thread_desc_, MRepeat, NRepeat, AccDataType, Empty and I0 are
// taken from the enclosing scope, which is not visible in this excerpt.
template <typename AScaleStruct, typename BScaleStruct>
|
||||
struct CScale
|
||||
{
|
||||
__device__ CScale() {}
|
||||
|
||||
// Number of accumulator elements handled per WMMA: 1 when both scale structs are Empty,
// otherwise whatever wmma_gemm.GetRegSizePerWmma() reports.
static constexpr auto reg_size_per_wmma =
|
||||
ck::is_same_v<BScaleStruct, Empty> && ck::is_same_v<AScaleStruct, Empty>
|
||||
? 1
|
||||
: wmma_gemm.GetRegSizePerWmma();
|
||||
// Packed 3-D descriptor for the combined scales, dimension order (K, M, N):
//   K length = max(A num_slice_k, B num_slice_k)
//   M length = A num_slice_mn
//   N length = B num_slice_mn
static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple(
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{},
|
||||
Number<AScaleStruct::num_slice_mn>{},
|
||||
Number<BScaleStruct::num_slice_mn>{}));
|
||||
using CScaleThreadDesc = decltype(c_scale_thread_desc);
|
||||
// Lengths of the three descriptor dimensions, read back for use as loop bounds.
static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{});
|
||||
static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{});
|
||||
static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{});
|
||||
// Static buffer type (AddressSpaceEnum::Vgpr, AccDataType elements) sized to hold every
// combined scale described by c_scale_thread_desc.
using ThreadStaticBuffer = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize()));
|
||||
|
||||
// Load: for every (m0, n0, k0) store a_scale[m0, k0] * b_scale[n0, k0] into
// c_scale_thread_bufs(I0) at offset (k0, m0, n0). Only buffer index I0 of each input
// struct is read.
// NOTE(review): k0 runs up to max(A, B num_slice_k) and is used to index BOTH the A and B
// thread descriptors; this presumably requires both descriptors to cover that K range —
// confirm against the instantiations.
__device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct)
|
||||
{
|
||||
using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc);
|
||||
using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc);
|
||||
|
||||
static_for<0, num_scale_m_block, 1>{}([&](auto m0) {
|
||||
static_for<0, num_scale_n_block, 1>{}([&](auto n0) {
|
||||
static_for<0, num_scale_k_block, 1>{}([&](auto k0) {
|
||||
constexpr index_t c_offset =
|
||||
CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0));
|
||||
constexpr index_t a_offset =
|
||||
AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0));
|
||||
constexpr index_t b_offset =
|
||||
BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0));
|
||||
|
||||
c_scale_thread_bufs(I0)(Number<c_offset>{}) =
|
||||
a_scale_struct.scale_thread_bufs(I0)[Number<a_offset>{}] *
|
||||
b_scale_struct.scale_thread_bufs(I0)[Number<b_offset>{}];
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
// Clear: zero the reg_size_per_wmma elements of vector 0 of the temporary accumulator.
__device__ void Clear()
|
||||
{
|
||||
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
}
|
||||
|
||||
// UpdateCThreadBuf: for each element t of the temporary accumulator, add
//   c_thread_buf_per_scale[t] * combined_scale
// into c_thread_buf at the offset c_thread_desc_ gives for (m_index, n_index, t).
// The combined scale is selected by (k_index, M scale index, N scale index); the M index
// additionally depends on t via AScaleStruct::reg_size_per_wmma.
// The temporary accumulator is only read here; callers reset it with Clear().
template <index_t k_index, index_t m_index, index_t n_index, typename CThreadBuf>
|
||||
__device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf)
|
||||
{
|
||||
static_for<0, reg_size_per_wmma, 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t));
|
||||
// M term: repeat index mapped onto the M scale blocks, plus a t-dependent
// sub-index; N term: repeat index mapped onto the N scale blocks.
constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple(
|
||||
k_index,
|
||||
(m_index * num_scale_m_block / MRepeat) % num_scale_m_block +
|
||||
(Number<t / (reg_size_per_wmma / AScaleStruct::reg_size_per_wmma)>{}) %
|
||||
AScaleStruct::reg_size_per_wmma,
|
||||
(n_index * num_scale_n_block / NRepeat) % num_scale_n_block));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()[Number<t>{}] *
|
||||
type_convert<AccDataType>(c_scale_thread_bufs(I0)[Number<cscale_offset>{}]);
|
||||
});
|
||||
}
|
||||
|
||||
// Combined A*B scales (a single buffer, index I0).
StaticallyIndexedArray<ThreadStaticBuffer, Number<1>{}> c_scale_thread_bufs;
|
||||
// Temporary accumulator: one vector of reg_size_per_wmma AccDataType values; the pipelines
// pass it to wmma_gemm.Run as the output before UpdateCThreadBuf is applied.
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, AccDataType, 1, reg_size_per_wmma, true>
|
||||
c_thread_buf_per_scale;
|
||||
};
|
||||
|
||||
__host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }
|
||||
|
||||
@@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
@@ -207,6 +209,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
c_thread_buf.Clear();
|
||||
|
||||
auto blockwise_gemm_func = [&]() {
|
||||
// Local load
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
@@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
b_scale_struct.scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
@@ -366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
}
|
||||
|
||||
// Run (A- and B-scaled overload): participates in overload resolution only when neither
// AScaleStruct nor BScaleStruct is Empty (see the enable_if below).
// Flow: global read of A/B -> write to block buffers -> per-iteration blockwise GEMM in
// which each K scale block is accumulated into a temporary buffer and then multiplied by
// the combined A*B scale while being added to c_thread_buf.
//   num_loop           - number of K iterations; the main loop runs while i < num_loop - 1.
//   num_loop_per_scale - used to compute the boolean passed to GlobalLoad of both scale
//                        structs.
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
|
||||
!ck::is_same_v<BScaleStruct, Empty>,
|
||||
bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
// Number of K scale blocks: the larger of the two scale structs' num_slice_k.
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
|
||||
// Per-thread staging buffers for the A and B operands.
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
Base::a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
Base::b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Holds the combined A*B scales and the temporary per-scale-block accumulator.
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
|
||||
auto c_scale_struct = CScaleStruct{};
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
// Combine the freshly loaded A and B scales into c_scale_struct.
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// One full block-level GEMM step: first copy every (k0, m0) / (k0, n0) tile from the block
// buffers into the thread buffers, then run the WMMA loops with scaling.
auto blockwise_gemm_func = [&]() {
|
||||
// Local load
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
Base::a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
Base::b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
// For each (m0, n0) tile and each K scale block: clear the temporary accumulator,
// accumulate KRepeat / NumScaleKBlock WMMA steps into it, then fold it (times the
// combined scale) into c_thread_buf.
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
// Gather the A operand elements for this k_inner step.
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
// Absolute K-repeat index: scale block base + position within the block.
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<Base::a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
// Gather the B operand elements for this k_inner step.
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<Base::b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
// Output goes to the temporary per-scale accumulator, not to c_thread_buf.
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
// Apply the combined scale for (kscale0, m0, n0) and add into c_thread_buf.
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// Read the next A/B tiles from global memory while the current ones are computed.
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Load the next scales; the flag is true when (i + 2) is a multiple of
// num_loop_per_scale.
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
|
||||
block_sync_lds();
|
||||
blockwise_gemm_func();
|
||||
|
||||
block_sync_lds();
|
||||
// The GEMM above used the previous combined scales; refresh them only now.
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
|
||||
// tail
|
||||
// Only TailNumber::Full is handled: one final GEMM on the last tiles written above.
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
blockwise_gemm_func();
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
// A[MRepeat, I1, I1, KPack]
|
||||
static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed(
|
||||
@@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
return TailNumber::Full;
|
||||
}
|
||||
|
||||
// KLoopParams: compile-time K-loop constants selected by the scale struct types.
// Primary template (at least one scale struct is not Empty):
//   KRepeatNoScale           = 1
//   NumScaleKBlock           = max(A num_slice_k, B num_slice_k), as a Number<>
//   KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock
// NOTE(review): KRepeatPerCluster comes from the enclosing scope, not visible in this excerpt.
template <typename AScaleStruct, typename BScaleStruct>
|
||||
struct KLoopParams
|
||||
{
|
||||
static constexpr auto KRepeatNoScale = 1;
|
||||
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock;
|
||||
};
|
||||
|
||||
// Specialization for <Empty, Empty>: KRepeatNoScale covers all of KRepeatPerCluster and
// the two scale-block constants are 1. Unlike the primary template, all three members
// are plain index_t values here.
template <>
|
||||
struct KLoopParams<Empty, Empty>
|
||||
{
|
||||
static constexpr index_t KRepeatNoScale = KRepeatPerCluster;
|
||||
static constexpr index_t NumScaleKBlock = 1;
|
||||
static constexpr index_t KRepeatPerNumScaleKBlock = 1;
|
||||
};
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
@@ -543,7 +747,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
@@ -576,6 +782,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0_offset + k0_inner, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(I0)[Number<
|
||||
b_scale_struct.scale_thread_bufs(I0)[Number<
|
||||
n0 * BScaleStruct::num_scale_k_block +
|
||||
(k0_offset + k0_inner) / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
@@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
__builtin_amdgcn_s_setprio(0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
@@ -982,7 +1190,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
BBlockBuffer&,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct&,
|
||||
index_t num_loop,
|
||||
index_t) const
|
||||
@@ -1319,6 +1529,248 @@ struct BlockwiseGemmWmmaops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
}
|
||||
|
||||
// Run (A- and B-scaled overload): participates in overload resolution only when neither
// AScaleStruct nor BScaleStruct is Empty (see the enable_if below).
// In this variant B is copied by b_blockwise_copy.Run directly into one of two per-thread
// buffers (b_thread_bufs), and the BBlockDesc / BBlockBuffer parameters are unnamed and
// unused; A still goes through RunRead / RunWrite into a_block_buf and is then copied to
// a_thread_buf by a_local_prefetch_func.
// NOTE(review): presumably this is the B-preshuffle pipeline mentioned in the commit
// message — confirm against the enclosing struct, which is not visible here.
//   num_loop           - number of K iterations; the main loop advances by 2 per pass and
//                        runs while i < num_loop - 2.
//   num_loop_per_scale - used to compute the boolean passed to GlobalLoad of both scale
//                        structs.
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
|
||||
!ck::is_same_v<BScaleStruct, Empty>,
|
||||
bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc&,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer&,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
// Number of K scale blocks: the larger of the two scale structs' num_slice_k.
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
// Two B thread buffers: one is consumed by the GEMM while the other is being filled.
// b_thread_buf itself is only used for its type.
StaticallyIndexedArray<decltype(b_thread_buf), Number<2>{}> b_thread_bufs;
|
||||
constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0);
|
||||
|
||||
// Holds the combined A*B scales and the temporary per-scale-block accumulator.
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
|
||||
auto c_scale_struct = CScaleStruct{};
|
||||
|
||||
// GEMM core over the B thread buffer selected by reg_buf. For each (m0, n0) tile and each
// K scale block: clear the temporary accumulator, accumulate KRepeat / NumScaleKBlock
// WMMA steps into it, then fold it (times the combined scale) into c_thread_buf.
auto gemm_core_func = [&](auto reg_buf) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
// Gather the A operand elements for this k_inner step.
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
// Absolute K-repeat index: scale block base + position within the block.
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
// Gather the B operand elements; note that n0 and k_index sit at different
// tuple positions than in the A descriptor.
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_bufs[reg_buf]
|
||||
[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
I0,
|
||||
I0,
|
||||
n0,
|
||||
I0,
|
||||
k_index,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
// Output goes to the temporary per-scale accumulator, not to c_thread_buf.
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
// Apply the combined scale for (kscale0, m0, n0) and add into c_thread_buf.
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// Copies every (m0, k0) A tile from a_block_buf into a_thread_buf.
auto a_local_prefetch_func = [&]() {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
// Global prefetch A1 B1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I0));
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Combine the freshly loaded A and B scales into c_scale_struct.
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// Local prefill A1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
// Global prefetch A2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Local prefetch A1
|
||||
block_sync_lds();
|
||||
a_local_prefetch_func();
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
// One pipeline step: wmma_reg_buf selects the B thread buffer consumed by the GEMM,
// local_read_buf the one being filled for the following step.
auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) {
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(local_read_buf));
|
||||
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf);
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
|
||||
// Load the next scales; the flag is true when (i + 2 + wmma_reg_buf) is a
// multiple of num_loop_per_scale.
a_scale_struct.template GlobalLoad<0>(
|
||||
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>(
|
||||
(i + 2 + wmma_reg_buf) % num_loop_per_scale == 0);
|
||||
|
||||
gemm_core_func(wmma_reg_buf);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// loop prefetch copy
|
||||
a_local_prefetch_func();
|
||||
|
||||
// The GEMM above used the previous combined scales; refresh them only now.
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
};
|
||||
|
||||
// Two steps per pass, with the B buffer roles swapped between them.
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
|
||||
i += 2;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// tail
|
||||
// Even: two GEMM steps remain (buffers I0 then I1); Odd: a single step on buffer I0.
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
b_blockwise_copy.Run(b_grid_desc,
|
||||
b_grid_buf,
|
||||
b_block_desc_k0_n0_n1_n2_k1,
|
||||
b_block_origin_idx,
|
||||
b_thread_bufs(I1));
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
|
||||
gemm_core_func(I0);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// tail Local Prefetch A1
|
||||
a_local_prefetch_func();
|
||||
|
||||
// Refresh the combined scales before the final GEMM step.
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
gemm_core_func(I1);
|
||||
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
gemm_core_func(I0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
static constexpr auto b_thread_desc_ =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(Number<KPack / B_K1 / B_KRow>{},
|
||||
|
||||
@@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
KInner,
|
||||
TransposeC>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::I2;
|
||||
using Base::I3;
|
||||
|
||||
using Base::A_K1;
|
||||
using Base::A_KRow;
|
||||
@@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_scale_struct.b_scale_thread_bufs(
|
||||
b_scale_struct.scale_thread_bufs(
|
||||
I0)[Number<n0 * BScaleStruct::num_scale_k_block +
|
||||
k0 / BScaleStruct::num_scale_krepeat>{}],
|
||||
b_thread_desc_,
|
||||
@@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename BScaleStruct>
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<ck::is_same_v<AScaleStruct, Empty>, bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
@@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
// BScaleThreadCopy
|
||||
AScaleStruct&,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
@@ -383,6 +388,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
// Local prefill 1
|
||||
@@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
|
||||
}
|
||||
}
|
||||
|
||||
template <bool HasMainLoop,
|
||||
TailNumber TailNum,
|
||||
typename AGridDesc,
|
||||
typename ABlockDesc,
|
||||
typename ABlockTransfer,
|
||||
typename AGridBuffer,
|
||||
typename ABlockBuffer,
|
||||
typename ABlockTransferStep,
|
||||
typename BGridDesc,
|
||||
typename BBlockDesc,
|
||||
typename BBlockTransfer,
|
||||
typename BGridBuffer,
|
||||
typename BBlockBuffer,
|
||||
typename BBlockTransferStep,
|
||||
typename CThreadBuffer,
|
||||
typename AScaleStruct,
|
||||
typename BScaleStruct,
|
||||
typename enable_if<!ck::is_same_v<AScaleStruct, Empty> &&
|
||||
!ck::is_same_v<BScaleStruct, Empty>,
|
||||
bool>::type = false>
|
||||
__device__ void Run(const AGridDesc& a_grid_desc,
|
||||
const ABlockDesc& a_block_desc,
|
||||
ABlockTransfer& a_blockwise_copy,
|
||||
const AGridBuffer& a_grid_buf,
|
||||
ABlockBuffer& a_block_buf,
|
||||
const ABlockTransferStep& a_block_copy_step,
|
||||
const BGridDesc& b_grid_desc,
|
||||
const BBlockDesc& b_block_desc,
|
||||
BBlockTransfer& b_blockwise_copy,
|
||||
const BGridBuffer& b_grid_buf,
|
||||
BBlockBuffer& b_block_buf,
|
||||
const BBlockTransferStep& b_block_copy_step,
|
||||
CThreadBuffer& c_thread_buf,
|
||||
AScaleStruct& a_scale_struct,
|
||||
BScaleStruct& b_scale_struct,
|
||||
index_t num_loop,
|
||||
index_t num_loop_per_scale) const
|
||||
{
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk();
|
||||
static constexpr auto NumScaleKBlock =
|
||||
Number<ck::math::max(AScaleStruct::num_slice_k, BScaleStruct::num_slice_k)>{};
|
||||
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeA>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeTypeB>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
|
||||
using CScaleStruct = typename Base::template CScale<AScaleStruct, BScaleStruct>;
|
||||
auto c_scale_struct = CScaleStruct{};
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
// Scales global load
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1);
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// Global prefetch 2, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
}
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// Local prefetch 1
|
||||
block_sync_lds();
|
||||
|
||||
auto local_load_func = [&]() {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(I0, m0, k0, I0, I0, I0, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(I0, n0, k0, I0, I0, I0, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
local_load_func();
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Main body, perform when at least 3 loops exist.
|
||||
if constexpr(HasMainLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
wmma_gemm.Run(a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
block_sync_lds();
|
||||
|
||||
local_load_func();
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 2));
|
||||
}
|
||||
|
||||
// Pre-tail, perform when at least 2 loops exist.
|
||||
if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full)
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
|
||||
// No RunRead or MoveSrcSliceWindow here, already finished them all!
|
||||
a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
c_scale_struct.Load(a_scale_struct, b_scale_struct);
|
||||
block_sync_lds();
|
||||
|
||||
local_load_func();
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
// Tail, always perform.
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) {
|
||||
c_scale_struct.Clear();
|
||||
static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeTypeA, KPack / A_KRow / KInner> a_thread_vec;
|
||||
vector_type<ComputeTypeB, KPack / B_KRow / KInner> b_thread_vec;
|
||||
static_for<0, KInner, 1>{}([&](auto k_inner) {
|
||||
static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
a_thread_vec.template AsType<ComputeTypeA>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / A_K1>{},
|
||||
m0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % A_K1>{}))>{}];
|
||||
});
|
||||
static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) {
|
||||
constexpr index_t kk = ik + k_inner * KPerWaveBlock;
|
||||
constexpr index_t k_index =
|
||||
kscale0 * (KRepeat / NumScaleKBlock) + k0;
|
||||
b_thread_vec.template AsType<ComputeTypeB>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(Number<kk / B_K1>{},
|
||||
n0,
|
||||
k_index,
|
||||
I0,
|
||||
I0,
|
||||
I0,
|
||||
Number<kk % B_K1>{}))>{}];
|
||||
});
|
||||
|
||||
using wmma_input_type_a =
|
||||
typename vector_type<ComputeTypeA, WmmaK / A_KRow>::type;
|
||||
using wmma_input_type_b =
|
||||
typename vector_type<ComputeTypeB, WmmaK / B_KRow>::type;
|
||||
|
||||
wmma_gemm.Run(
|
||||
a_thread_vec.template AsType<wmma_input_type_a>(),
|
||||
b_thread_vec.template AsType<wmma_input_type_b>(),
|
||||
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<0>{}));
|
||||
});
|
||||
});
|
||||
c_scale_struct.template UpdateCThreadBuf<kscale0, m0, n0>(c_thread_buf);
|
||||
});
|
||||
});
|
||||
});
|
||||
// Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
using Base::a_thread_copy_;
|
||||
using Base::a_thread_desc_;
|
||||
|
||||
Reference in New Issue
Block a user