This commit is contained in:
Sami Remes
2025-08-21 11:33:53 +00:00
parent abcf2f3c97
commit 3dc6b7c71a
2 changed files with 199 additions and 165 deletions

View File

@@ -121,7 +121,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
KPack,
true>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -147,6 +146,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
@@ -334,16 +334,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize());
using AScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize()));
using BScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize()));
StaticallyIndexedArray<AScaleBufferType, Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<BScaleBufferType, Number<2>{}> b_scale_thread_bufs;
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize());
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
// Global prefetch 1
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -381,12 +380,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
});
// Local prefill 1
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
// Global prefetch 2
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
@@ -400,7 +399,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_bufs(I1));
a_scale_thread_bufs(I0));
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
@@ -420,7 +419,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs(I1));
b_scale_thread_bufs(I0));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
@@ -429,7 +428,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
// Double register buffer for non-scaled gemm computation
// 1. Reduce register pressure
// 2. Decouple the dependency between mfma instruction and scale-fma instruction
// 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
2,
@@ -458,6 +457,33 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
});
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
// Fill first MFMA buffer
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(I0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(I0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
});
__builtin_amdgcn_sched_barrier(0);
// main body
@@ -468,8 +494,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
{
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
block_sync_lds();
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(local_read_buf));
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
@@ -477,88 +503,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
// Clear buffer for new MFMA computation
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
});
// Apply scaling with packed FMA and accumulate to main buffer
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf,
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf,
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
@@ -584,15 +528,114 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs(local_read_buf));
b_scale_thread_buf(local_read_buf));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
// Update scales for next iteration using the loaded values
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[m0];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
// Compute offsets
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto a_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
constexpr auto b_local_buf_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
constexpr auto local_buf_id =
Number<mfma_reg_buf ^
((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{};
// Clear the current mfma output buffer
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
constexpr auto m0_offset = (a_local_buf_offset + HotloopLocalBufSwitch * mfma_reg_buf) % 2;
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[local_buf_id][Number<a_thread_desc_.CalculateOffset(
make_tuple(m0_offset, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[local_buf_id][Number<b_thread_desc_.CalculateOffset(
make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
});
// Apply scaling
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
// .template AsType<AccDataType>()[Number<t>{}] *
// type_convert<AccDataType>(c_scale_thread_buf[m0]);
// });
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
};
block_sync_lds();
static_for<0, KRepeat, 1>{}([&](auto k) {
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
a_block_buf.At(mfma_reg_buf),
a_thread_desc_,
make_tuple(m0, I0, k, I0),
a_thread_buf);
});
static_for<0, NRepeat, 1>{}([&](auto n0) {
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
b_block_buf.At(mfma_reg_buf),
b_thread_desc_,
make_tuple(n0, I0, k, I0),
b_thread_buf);
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_buf[mfma_reg_buf][m0] *
b_scale_thread_buf[mfma_reg_buf][I0];
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
}
LoopFunc(I0, I1);
LoopFunc(I1, I0);
@@ -602,65 +645,45 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
}
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
// if constexpr(TailNum == TailNumber::Full)
// {
// static_for<0, MRepeat, 1>{}([&](auto m0) {
// static_for<0, NRepeat, 1>{}([&](auto n0) {
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
// .template AsType<AccDataType>()(Number<t>{}) = 0;
// });
// static_for<0, KRepeat, 1>{}([&](auto k0) {
// vector_type<ComputeDataType, KPack> a_thread_vec;
// vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, NRepeat, 1>{}([&](auto n0) {
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
// static_for<0, KPack, 1>{}([&](auto ik) {
// a_thread_vec.template AsType<ComputeDataType>()(ik) =
// a_thread_buf[Number<a_thread_desc_.CalculateOffset(
// make_tuple(m0, I0, k0, ik))>{}];
// b_thread_vec.template AsType<ComputeDataType>()(ik) =
// b_thread_buf[Number<b_thread_desc_.CalculateOffset(
// make_tuple(n0, I0, k0, ik))>{}];
// });
// Clear buffer for new MFMA computation
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
// using mfma_input_type =
// typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
});
// Apply scaling with packed FMA and accumulate to main buffer
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
});
});
});
// xdlops_gemm.template Run<>(
// a_thread_vec.template AsType<mfma_input_type>(),
// b_thread_vec.template AsType<mfma_input_type>(),
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
// });
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
// constexpr index_t c_offset =
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
// c_thread_buf(Number<c_offset>{}) +=
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
// .template AsType<AccDataType>()[Number<t>{}] *
// type_convert<AccDataType>(c_scale_thread_buf[m0]);
// });
// });
// });
__builtin_amdgcn_sched_barrier(0);
}
}

View File

@@ -973,8 +973,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
constexpr auto c_block_size =
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
b_block_space_size_aligned * sizeof(LDSTypeB)),
return math::max((2 * a_block_space_size_aligned * sizeof(LDSTypeA) +
2 * b_block_space_size_aligned * sizeof(LDSTypeB)),
c_block_size * sizeof(CShuffleDataType));
}
@@ -1327,15 +1327,26 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
// LDS allocation for A and B: be careful of alignment
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
// Cast after lds
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
auto a_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_buf_2 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeA*>(p_shared) + a_block_space_size_aligned, a_block_desc_ak0_m_ak1.GetElementSpaceSize());
auto a_block_bufs = make_tuple(a_block_buf_1, a_block_buf_2);
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
auto b_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
2 * a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_buf_2 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
static_cast<LDSTypeB*>(p_shared) +
2 * a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB) +
b_block_space_size_aligned,
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
auto b_block_bufs = make_tuple(b_block_buf_1, b_block_buf_2);
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
@@ -1410,13 +1421,13 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
a_block_desc_ak0_m_ak1,
a_blockwise_copy,
a_grid_buf,
a_block_buf,
a_block_bufs,
a_block_slice_copy_step,
b_grid_desc_bk0_n_bk1,
b_block_desc_bk0_n_bk1,
b_blockwise_copy,
b_grid_buf,
b_block_buf,
b_block_bufs,
b_block_slice_copy_step,
c_scale_thread_desc,