Another alternative: this is numerically correct, but lets the compiler do more work to enable the double buffering.

This commit is contained in:
Sami Remes
2025-09-11 13:38:29 +00:00
parent 180a436cca
commit 64edaacebe
2 changed files with 115 additions and 207 deletions

View File

@@ -121,7 +121,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
KPack,
true>;
using Base::I0;
using Base::I1;
using Base::KRepeat;
using Base::xdlops_gemm;
using typename Base::HotLoopInstList;
@@ -144,10 +143,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using Base::AMmaKStride;
using Base::BMmaKStride;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
static constexpr index_t PrefetchStages = 2;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
{
@@ -331,19 +329,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
"Pipeline v3 only support scaleblocksliceN=1");
// assume kperblock = scaleblockk
using AThreadBufType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize()));
using BThreadBufType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize()));
AThreadBufType a_thread_buf{};
BThreadBufType b_thread_buf{};
using AScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize()));
using BScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize()));
StaticallyIndexedArray<AScaleBufferType, Number<1>{}> a_scale_thread_bufs;
StaticallyIndexedArray<BScaleBufferType, Number<1>{}> b_scale_thread_bufs;
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
b_thread_desc_.GetElementSpaceSize());
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
a_scale_thread_desc.GetElementSpaceSize());
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
b_scale_thread_desc.GetElementSpaceSize());
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
c_scale_thread_desc.GetElementSpaceSize());
@@ -359,7 +352,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_bufs(I0));
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
@@ -379,12 +372,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs(I0));
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
});
// Local prefill 1
@@ -403,7 +396,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_bufs(I0));
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
a_scale_thread_copy_step.At(Number<0>{}));
});
@@ -423,19 +416,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs(I0));
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
// Initialize C
c_thread_buf.Clear();
// Double register buffer for non-scaled gemm computation
// 1. Reduce register pressure
// 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
// Double buffer for c_thread_buf_per_scale to enable temporal decoupling
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
AccDataType,
2,
2, // Double buffer
xdlops_gemm.GetRegSizePerXdlops(),
true>
c_thread_buf_per_scale;
@@ -461,33 +452,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
});
});
// Clear the first MFMA buffer
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(I0).template AsType<AccDataType>()(
Number<t>{}) = 0;
});
// Fill first MFMA buffer with index I0, this output used in the first part of main loop for
// scale-FMA
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
[Number<a_thread_desc_.CalculateOffset(make_tuple(I0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
[Number<b_thread_desc_.CalculateOffset(make_tuple(I0, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(I0));
});
__builtin_amdgcn_sched_barrier(0);
// main body
@@ -507,72 +471,26 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_bufs(I0));
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_bufs(I0));
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[m0];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
// Calculate buffer offsets using future tile approach
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
// Use double buffering with temporal offset to decouple MFMA and scaling
constexpr auto buffer_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
// Calculate future tile data offsets
constexpr auto a_future_tile_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
constexpr auto b_future_tile_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
// Clear the MFMA output buffer for future tile
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
// Compute MFMA for future tile
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
static_for<0, KPack, 1>{}([&](auto ik) {
// Use future tile offsets for MFMA computation
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(a_future_tile_offset, I0, k0, ik))>{}];
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(b_future_tile_offset, I0, k0, ik))>{}];
make_tuple(n0, I0, k0, ik))>{}];
});
using mfma_input_type =
@@ -583,24 +501,16 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(
Number<mfma_buf_offset>{}));
Number<buffer_offset>{}));
});
// Run the element-wise FMA with data from previous iteration buffer
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale // mfma output from previous iteration
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec
.template AsType<pk_fma_type>()[Number<0>{}], // scales c=a*b
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale
.GetVectorTypeReference(Number<buffer_offset>{})
.template AsType<AccDataType>()(Number<t>{}) *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
});
});
});
@@ -625,90 +535,89 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
});
});
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) =
a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
});
HotLoopScheduler();
__builtin_amdgcn_sched_barrier(0);
// Compute scale values early to allow better instruction scheduling
static_for<0, MRepeat, 1>{}([&](auto m0) {
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
});
// Prefetch scale data early to overlap with MFMA computation
static_for<0, MRepeat, 1>{}([&](auto m0) {
a_scale_thread_copy.Run(a_scale_grid_desc,
a_scale_grid_buf,
a_scale_thread_desc,
make_tuple(m0, I0),
a_scale_thread_buf);
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
});
if constexpr(NumKBlockPerScale == 1)
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
}
else
{
a_scale_thread_copy.MoveSrcSliceWindow(
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
}
b_scale_thread_copy.Run(b_scale_grid_desc,
b_scale_grid_buf,
b_scale_thread_desc,
make_tuple(I0, I0),
b_scale_thread_buf);
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
i += 1;
} while(i < (num_loop - 1));
}
// __builtin_amdgcn_sched_barrier(0);
// tail
if constexpr(TailNum == TailNumber::Full)
{
static_for<0, MRepeat, 1>{}([&](auto m0) {
vector_type<AccDataType, 2> c_scale_thread_vec;
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
c_scale_thread_buf[m0];
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
c_scale_thread_buf[m0];
static_for<0, NRepeat, 1>{}([&](auto n0) {
// Calculate buffer offsets using the same future tile approach
constexpr auto mfma_buf_offset =
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
constexpr auto scale_buf_offset =
// Use double buffering with temporal offset in tail section as well
constexpr auto buffer_offset =
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
// Calculate future tile data offsets
constexpr auto a_future_tile_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
constexpr auto b_future_tile_offset =
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
});
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
// Skip MFMA computation for the last tile to avoid out-of-bounds
if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1))))
{
// Clear the MFMA buffer for future tile computation
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
.template AsType<AccDataType>()(Number<t>{}) = 0;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(m0, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(n0, I0, k0, ik))>{}];
});
// Compute MFMA for future tile
static_for<0, KRepeat, 1>{}([&](auto k0) {
vector_type<ComputeDataType, KPack> a_thread_vec;
vector_type<ComputeDataType, KPack> b_thread_vec;
using mfma_input_type =
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
static_for<0, KPack, 1>{}([&](auto ik) {
a_thread_vec.template AsType<ComputeDataType>()(ik) =
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
make_tuple(a_future_tile_offset, I0, k0, ik))>{}];
b_thread_vec.template AsType<ComputeDataType>()(ik) =
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
make_tuple(b_future_tile_offset, I0, k0, ik))>{}];
});
using mfma_input_type =
typename vector_type<ComputeDataType,
xdlops_gemm.K1PerXdlops>::type;
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(
Number<mfma_buf_offset>{}));
});
}
// Scale and accumulate the previous iteration's result
constexpr auto c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
c_thread_buf_per_scale
.GetVectorTypeReference(Number<scale_buf_offset>{})
.template AsType<pk_fma_type>()[t],
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
.template AsType<pk_fma_type>()[t]);
xdlops_gemm.template Run<>(
a_thread_vec.template AsType<mfma_input_type>(),
b_thread_vec.template AsType<mfma_input_type>(),
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{}));
});
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
constexpr index_t c_offset =
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
c_thread_buf(Number<c_offset>{}) +=
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{})
.template AsType<AccDataType>()(Number<t>{}) *
type_convert<AccDataType>(c_scale_thread_buf[m0]);
});
});
});

View File

@@ -231,24 +231,23 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
}
};
// constexpr index_t minimum_occupancy = [&]() {
// // if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
// // is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
// // {
// // // FIXME: many instances have many spills with occupancy > 1, a better
// solution
// // // needed to get best performance
// // return 1;
// // }
// // else
// {
// return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
// MPerBlock * NPerBlock / BlockSize > 64)
// ? 1
// : 2;
// }
// }();
constexpr index_t minimum_occupancy = 2;
constexpr index_t minimum_occupancy = [&]() {
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
{
// FIXME: many instances have many spills with occupancy > 1, a better solution
// needed to get best performance
return 1;
}
else
{
return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
MPerBlock * NPerBlock / BlockSize > 64)
? 1
: 2;
}
}();
// constexpr index_t minimum_occupancy = 2;
if(has_main_k_block_loop)
{