Conv:TF32: add more instances - 1 (#2867)

* conv:tf32:add more instances
* add instances of device_grouped_conv_fwd_xdl_f32_comp_instances
* add instances of device_grouped_conv_fwd_xdl_f32_tf32_mem_instances
* add instances of device_grouped_conv_fwd_xdl_large_tensor_f32_tf32_instances
* remove gnhwc/ngchw/ngcdhw instances
This commit is contained in:
yinglu
2025-09-25 09:27:18 +08:00
committed by GitHub
parent f076f207ce
commit df97a286d5
92 changed files with 4273 additions and 443 deletions

View File

@@ -54,6 +54,9 @@ struct BlockwiseGemmXdlops_pipeline_base
static constexpr auto xdlops_gemm =
XdlopsGemm<ComputeDataType, MPerXDL, NPerXDL, KPack, ComputeDataType, TransposeC>{};
using ComputeDataTypeBuf =
conditional_t<std::is_same<ComputeDataType, ck::tf32_t>::value, float, ComputeDataType>;
static constexpr index_t AMmaKStride = KPack;
static constexpr index_t BMmaKStride = KPack;
@@ -376,7 +379,7 @@ struct BlockwiseGemmXdlops_pipeline_base
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
@@ -386,7 +389,7 @@ struct BlockwiseGemmXdlops_pipeline_base
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,

View File

@@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
@@ -185,9 +187,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
@@ -240,20 +242,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -301,20 +303,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -439,6 +441,8 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
@@ -486,9 +490,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
@@ -551,20 +555,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -640,20 +644,20 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -704,7 +708,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
I1));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
@@ -714,7 +718,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,

View File

@@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::a_block_desc_m0_m1_m2_k;
using Base::b_block_desc_n0_n1_n2_k;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t AMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
static constexpr index_t BMmaKStride = xdlops_gemm.K0PerXdlops * KPack;
@@ -222,10 +224,12 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
// sizeof(ComputeDataTypeBuf) /
// sizeof(BDataType)
// ? sizeof(ComputeDataTypeBuf) /
// sizeof(ADataType) : sizeof(ComputeDataTypeBuf)
// / sizeof(BDataType);
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
@@ -351,9 +355,9 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
{
__builtin_amdgcn_sched_barrier(0);
// assume kperblock = scaleblockk
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
@@ -516,17 +520,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
@@ -535,7 +539,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -646,17 +650,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
@@ -665,7 +669,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -737,17 +741,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
@@ -756,7 +760,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -791,17 +795,17 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat / num_scale_k_block, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0,
I0,
kscale0 * KRepeat / num_scale_k_block + k0,
ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0,
I0,
@@ -810,7 +814,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -842,7 +846,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::b_thread_desc_;
using Base::c_thread_desc_;
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
@@ -852,7 +856,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,

View File

@@ -140,6 +140,8 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 1;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
@@ -205,12 +207,12 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
{
// assume kperblock = scaleblockk
ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_scale_thread_desc.GetElementSpaceSize());
// Global prefetch 1
@@ -279,20 +281,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -360,20 +362,20 @@ struct BlockwiseGemmXdlops_pipeline_v1_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),

View File

@@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
using Base::BMmaKStride;
using Base::WaveSize;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t WgpPerCU =
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
@@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
@@ -284,20 +286,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -355,20 +357,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -410,20 +412,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -461,20 +463,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -628,6 +630,8 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
using Base::b_block_desc_n0_n1_n2_k;
using Base::WaveSize;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
@@ -716,9 +720,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
@@ -786,20 +790,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -885,20 +889,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -961,20 +965,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -1037,20 +1041,20 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -1129,7 +1133,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
I1));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
@@ -1139,7 +1143,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,

View File

@@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::BMmaKStride;
using Base::WaveSize;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t WgpPerCU =
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
@@ -257,9 +259,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
{
// assume kperblock = scaleblockk
ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
@@ -351,20 +353,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -457,20 +459,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -547,20 +549,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
@@ -605,20 +607,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
static_for<0, NRepeat, 1>{}([&](auto n0) {
c_thread_buf_per_scale.Clear();
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),

View File

@@ -141,6 +141,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
using Base::BMmaKStride;
using Base::WaveSize;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t WgpPerCU =
(4 * WaveSize / BlockSize) >= 1 ? 4 * WaveSize / BlockSize : 1;
static constexpr index_t FullMemBandPrefetchStages = math::integer_divide_ceil(
@@ -225,9 +227,9 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
@@ -285,20 +287,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -356,20 +358,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -411,20 +413,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -462,20 +464,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -629,6 +631,8 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
using Base::b_block_desc_n0_n1_n2_k;
using Base::WaveSize;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t NumMacClusters = CK_EXPERIMENTAL_INTER_WAVE_SCHEDULING_MAC_CLUSTERS;
static constexpr index_t KPerInnerLoop = math::max(KPerThread / NumMacClusters, KPack);
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
@@ -732,12 +736,12 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
{
ignore = num_loop_per_scale;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_scale_thread_desc.GetElementSpaceSize());
// Global prefetch 1
@@ -821,20 +825,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -942,20 +946,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -1039,20 +1043,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -1123,20 +1127,20 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
static_for<0, KPerInnerLoop, KPack>{}([&](auto k_) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, k_ + ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, k_ + ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -1223,7 +1227,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
I1));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,
@@ -1233,7 +1237,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_b_scale<BlockGemmPipelineScheduler::Inter
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPerInnerLoop>,

View File

@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
@@ -196,10 +198,10 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
// sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
// ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
// : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
@@ -295,9 +297,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
index_t num_loop) const
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
@@ -364,20 +366,20 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -424,20 +426,20 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

View File

@@ -143,6 +143,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
@@ -196,10 +198,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
// sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
// ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
// : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
@@ -329,9 +331,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
"Pipeline v3 only support scaleblocksliceN=1");
// assume kperblock = scaleblockk
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
@@ -476,20 +478,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
@@ -578,20 +580,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),

View File

@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
@@ -195,10 +197,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
// stage 1
// Separate this part?
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
// sizeof(ComputeDataType) / sizeof(BDataType)
// ? sizeof(ComputeDataType) / sizeof(ADataType)
// : sizeof(ComputeDataType) / sizeof(BDataType);
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataTypeBuf) / sizeof(ADataType) >
// sizeof(ComputeDataTypeBuf) / sizeof(BDataType)
// ? sizeof(ComputeDataTypeBuf) / sizeof(ADataType)
// : sizeof(ComputeDataTypeBuf) / sizeof(BDataType);
constexpr auto num_mfma_stage1 = num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
constexpr auto num_mfma_per_issue =
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
@@ -307,13 +309,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
{
__builtin_amdgcn_sched_barrier(0);
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// B scale buffer
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_scale_thread_desc.GetElementSpaceSize());
// Global prefetch 1
@@ -429,20 +431,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -491,20 +493,20 @@ struct BlockwiseGemmXdlops_pipeline_v3_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

View File

@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 2;
static constexpr index_t GlobalBufferNum = 1;
@@ -264,9 +266,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
@@ -369,22 +371,22 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf]
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -439,20 +441,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -492,20 +494,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -524,20 +526,20 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

View File

@@ -142,6 +142,8 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 2;
static constexpr index_t GlobalBufferNum = 1;
@@ -277,13 +279,13 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
index_t num_loop,
index_t num_loop_per_scale) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// B scale buffer
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_thread_buf), Number<2>{}> a_thread_bufs;
@@ -478,22 +480,22 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf]
[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf]
[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -549,20 +551,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -603,20 +605,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -635,20 +637,20 @@ struct BlockwiseGemmXdlops_pipeline_v4_b_scale<BlockGemmPipelineScheduler::Intra
static_for<0, KRepeat, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_bufs[mfma_reg_buf][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_bufs[mfma_reg_buf][Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

View File

@@ -144,6 +144,8 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
using Base::AMmaKStride;
using Base::BMmaKStride;
using ComputeDataTypeBuf = typename Base::ComputeDataTypeBuf;
static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 2;
@@ -346,9 +348,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
CThreadBuffer& c_thread_buf,
index_t num_loop) const
{
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataTypeBuf>(
b_thread_desc_.GetElementSpaceSize());
// Global prefetch 1
@@ -405,8 +407,8 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
do
{
auto LoopFunc = [&](auto vmem_buf) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KRepeat, 1>{}([&](auto k0) {
if constexpr(k0 == (KRepeat - 1))
@@ -427,18 +429,18 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
typename vector_type<ComputeDataTypeBuf,
xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
@@ -481,8 +483,8 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
}
// tail
auto ReadWriteCompFunc = [&](auto vmem_buf) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KRepeat, 1>{}([&](auto k0) {
if constexpr(k0 == (KRepeat - 1))
@@ -497,18 +499,18 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -540,25 +542,25 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
HotLoopScheduler();
};
auto ReadCompFunc = [&]() {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> a_thread_vec;
vector_type<ComputeDataTypeBuf, KPack> b_thread_vec;
static_for<0, KRepeat - 1, 1>{}([&](auto k0) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -591,16 +593,16 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
static_for<0, MRepeat, 1>{}([&](auto m0) {
static_for<0, NRepeat, 1>{}([&](auto n0) {
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
a_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(m0, I0, I0, ik))>{}];
});
static_for<0, KPack, 1>{}([&](auto ik) {
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
b_thread_vec.template AsType<ComputeDataTypeBuf>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(n0, I0, I0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
typename vector_type<ComputeDataTypeBuf, xdlops_gemm.K1PerXdlops>::type;
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
@@ -637,7 +639,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
make_naive_tensor_descriptor_packed(make_tuple(Number<NRepeat>{}, I1, I1, Number<KPack>{}));
using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<ADataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(a_block_desc_m0_m1_m2_k),
decltype(a_thread_desc_),
Sequence<1, 1, 1, KPack>,
@@ -647,7 +649,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
A_K1>;
using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<BDataType,
ComputeDataType,
ComputeDataTypeBuf,
decltype(b_block_desc_n0_n1_n2_k),
decltype(b_thread_desc_),
Sequence<1, 1, 1, KPack>,