mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
Another alternative, this is numerically correct, but lets compiler do more work to enable the double buffering
This commit is contained in:
@@ -121,7 +121,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
KPack,
|
||||
true>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
@@ -144,10 +143,9 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
using Base::AMmaKStride;
|
||||
using Base::BMmaKStride;
|
||||
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
@@ -331,19 +329,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
static_assert(CScaleThreadDesc{}.GetLength(Number<2>{}) == 1,
|
||||
"Pipeline v3 only support scaleblocksliceN=1");
|
||||
// assume kperblock = scaleblockk
|
||||
using AThreadBufType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize()));
|
||||
using BThreadBufType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize()));
|
||||
AThreadBufType a_thread_buf{};
|
||||
BThreadBufType b_thread_buf{};
|
||||
|
||||
using AScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize()));
|
||||
using BScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize()));
|
||||
StaticallyIndexedArray<AScaleBufferType, Number<1>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<BScaleBufferType, Number<1>{}> b_scale_thread_bufs;
|
||||
auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
@@ -359,7 +352,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
@@ -379,12 +372,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs(I0));
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
// Local prefill 1
|
||||
@@ -403,7 +396,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
@@ -423,19 +416,17 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs(I0));
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Initialize C
|
||||
c_thread_buf.Clear();
|
||||
|
||||
// Double register buffer for non-scaled gemm computation
|
||||
// 1. Reduce register pressure
|
||||
// 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
|
||||
// Double buffer for c_thread_buf_per_scale to enable temporal decoupling
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
2,
|
||||
2, // Double buffer
|
||||
xdlops_gemm.GetRegSizePerXdlops(),
|
||||
true>
|
||||
c_thread_buf_per_scale;
|
||||
@@ -461,33 +452,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
});
|
||||
|
||||
// Clear the first MFMA buffer
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0).template AsType<AccDataType>()(
|
||||
Number<t>{}) = 0;
|
||||
});
|
||||
|
||||
// Fill first MFMA buffer with index I0, this output used in the first part of main loop for
|
||||
// scale-FMA
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) = a_thread_buf
|
||||
[Number<a_thread_desc_.CalculateOffset(make_tuple(I0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) = b_thread_buf
|
||||
[Number<b_thread_desc_.CalculateOffset(make_tuple(I0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(I0));
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
@@ -507,72 +471,26 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_bufs(I0));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs(I0));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// Calculate buffer offsets using future tile approach
|
||||
constexpr auto mfma_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
constexpr auto scale_buf_offset =
|
||||
// Use double buffering with temporal offset to decouple MFMA and scaling
|
||||
constexpr auto buffer_offset =
|
||||
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
|
||||
// Calculate future tile data offsets
|
||||
constexpr auto a_future_tile_offset =
|
||||
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
|
||||
constexpr auto b_future_tile_offset =
|
||||
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
|
||||
|
||||
// Clear the MFMA output buffer for future tile
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
|
||||
// Compute MFMA for future tile
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
// Use future tile offsets for MFMA computation
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(a_future_tile_offset, I0, k0, ik))>{}];
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(b_future_tile_offset, I0, k0, ik))>{}];
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
@@ -583,24 +501,16 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<mfma_buf_offset>{}));
|
||||
Number<buffer_offset>{}));
|
||||
});
|
||||
|
||||
// Run the element-wise FMA with data from previous iteration buffer
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale // mfma output from previous iteration
|
||||
.GetVectorTypeReference(Number<scale_buf_offset>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec
|
||||
.template AsType<pk_fma_type>()[Number<0>{}], // scales c=a*b
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<buffer_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -625,90 +535,89 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) =
|
||||
a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// Compute scale values early to allow better instruction scheduling
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
// Prefetch scale data early to overlap with MFMA computation
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_buf);
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
|
||||
if constexpr(NumKBlockPerScale == 1)
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<2>{}));
|
||||
}
|
||||
else
|
||||
{
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(
|
||||
a_scale_grid_desc, a_scale_thread_copy_step.At(Number<1>{}));
|
||||
}
|
||||
|
||||
b_scale_thread_copy.Run(b_scale_grid_desc,
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_buf);
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// Calculate buffer offsets using the same future tile approach
|
||||
constexpr auto mfma_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
constexpr auto scale_buf_offset =
|
||||
// Use double buffering with temporal offset in tail section as well
|
||||
constexpr auto buffer_offset =
|
||||
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
|
||||
// Calculate future tile data offsets
|
||||
constexpr auto a_future_tile_offset =
|
||||
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
|
||||
constexpr auto b_future_tile_offset =
|
||||
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
// Skip MFMA computation for the last tile to avoid out-of-bounds
|
||||
if constexpr(!((m0 == (MRepeat - 1)) && (n0 == (NRepeat - 1))))
|
||||
{
|
||||
// Clear the MFMA buffer for future tile computation
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
// Compute MFMA for future tile
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(a_future_tile_offset, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(b_future_tile_offset, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(
|
||||
Number<mfma_buf_offset>{}));
|
||||
});
|
||||
}
|
||||
|
||||
// Scale and accumulate the previous iteration's result
|
||||
constexpr auto c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<scale_buf_offset>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{}));
|
||||
});
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
c_thread_buf(Number<c_offset>{}) +=
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<buffer_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) *
|
||||
type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
@@ -231,24 +231,23 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3
|
||||
}
|
||||
};
|
||||
|
||||
// constexpr index_t minimum_occupancy = [&]() {
|
||||
// // if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
|
||||
// // is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
// // {
|
||||
// // // FIXME: many instances have many spills with occupancy > 1, a better
|
||||
// solution
|
||||
// // // needed to get best performance
|
||||
// // return 1;
|
||||
// // }
|
||||
// // else
|
||||
// {
|
||||
// return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
|
||||
// MPerBlock * NPerBlock / BlockSize > 64)
|
||||
// ? 1
|
||||
// : 2;
|
||||
// }
|
||||
// }();
|
||||
constexpr index_t minimum_occupancy = 2;
|
||||
constexpr index_t minimum_occupancy = [&]() {
|
||||
if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout> &&
|
||||
is_same_v<tensor_layout::gemm::RowMajor, BLayout>)
|
||||
{
|
||||
// FIXME: many instances have many spills with occupancy > 1, a better solution
|
||||
// needed to get best performance
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
return (BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave &&
|
||||
MPerBlock * NPerBlock / BlockSize > 64)
|
||||
? 1
|
||||
: 2;
|
||||
}
|
||||
}();
|
||||
// constexpr index_t minimum_occupancy = 2;
|
||||
|
||||
if(has_main_k_block_loop)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user