mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
WIP
This commit is contained in:
@@ -121,7 +121,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
KPack,
|
||||
true>;
|
||||
using Base::I0;
|
||||
using Base::I1;
|
||||
using Base::KRepeat;
|
||||
using Base::xdlops_gemm;
|
||||
using typename Base::HotLoopInstList;
|
||||
@@ -147,6 +146,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr index_t HotloopLocalBufSwitch = MRepeat % 2 == 0 ? 0 : 1;
|
||||
|
||||
__host__ static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
@@ -334,16 +334,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_thread_desc_.GetElementSpaceSize());
|
||||
auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ComputeDataType>(
|
||||
b_thread_desc_.GetElementSpaceSize());
|
||||
auto a_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize());
|
||||
auto b_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize());
|
||||
using AScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
a_scale_thread_desc.GetElementSpaceSize()));
|
||||
using BScaleBufferType = decltype(make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
b_scale_thread_desc.GetElementSpaceSize()));
|
||||
StaticallyIndexedArray<AScaleBufferType, Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<BScaleBufferType, Number<2>{}> b_scale_thread_bufs;
|
||||
auto c_scale_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, AccDataType>(
|
||||
c_scale_thread_desc.GetElementSpaceSize());
|
||||
|
||||
StaticallyIndexedArray<decltype(a_scale_thread_buf), Number<2>{}> a_scale_thread_bufs;
|
||||
StaticallyIndexedArray<decltype(b_scale_thread_buf), Number<2>{}> b_scale_thread_bufs;
|
||||
|
||||
// Global prefetch 1
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
@@ -381,12 +380,12 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[I0][m0] * b_scale_thread_bufs[I0][I0];
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[m0] * b_scale_thread_buf[I0];
|
||||
});
|
||||
|
||||
// Local prefill 1
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(I0));
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(I0));
|
||||
|
||||
// Global prefetch 2
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
@@ -400,7 +399,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_scale_grid_buf,
|
||||
a_scale_thread_desc,
|
||||
make_tuple(m0, I0),
|
||||
a_scale_thread_bufs(I1));
|
||||
a_scale_thread_bufs(I0));
|
||||
a_scale_thread_copy.MoveSrcSliceWindow(a_scale_grid_desc,
|
||||
a_scale_thread_copy_step.At(Number<0>{}));
|
||||
});
|
||||
@@ -420,7 +419,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs(I1));
|
||||
b_scale_thread_bufs(I0));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
@@ -429,7 +428,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
|
||||
// Double register buffer for non-scaled gemm computation
|
||||
// 1. Reduce register pressure
|
||||
// 2. Decouple the dependency between mfma instruction and scale-fma instruction
|
||||
// 2. Decouple the dependency between mfma instruction and scale-fma instruction following.
|
||||
StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
|
||||
AccDataType,
|
||||
2,
|
||||
@@ -458,6 +457,33 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
|
||||
// Fill first MFMA buffer
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(I0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(I0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
});
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
@@ -468,8 +494,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
{
|
||||
auto LoopFunc = [&](auto mfma_reg_buf, auto local_read_buf) {
|
||||
block_sync_lds();
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf);
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
|
||||
a_blockwise_copy.RunWrite(a_block_desc, a_block_buf.At(local_read_buf));
|
||||
b_blockwise_copy.RunWrite(b_block_desc, b_block_buf.At(local_read_buf));
|
||||
|
||||
a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf);
|
||||
b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
|
||||
@@ -477,88 +503,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step);
|
||||
b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto mfma_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
constexpr auto scale_buf_offset =
|
||||
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
|
||||
// Clear buffer for new MFMA computation
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
|
||||
});
|
||||
|
||||
// Apply scaling with packed FMA and accumulate to main buffer
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<scale_buf_offset>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf,
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf,
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_scale_thread_copy.Run(a_scale_grid_desc,
|
||||
a_scale_grid_buf,
|
||||
@@ -584,15 +528,114 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
b_scale_grid_buf,
|
||||
b_scale_thread_desc,
|
||||
make_tuple(I0, I0),
|
||||
b_scale_thread_bufs(local_read_buf));
|
||||
b_scale_thread_buf(local_read_buf));
|
||||
|
||||
b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, b_scale_thread_copy_step);
|
||||
|
||||
// Update scales for next iteration using the loaded values
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_bufs[mfma_reg_buf][m0] * b_scale_thread_bufs[mfma_reg_buf][I0];
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) =
|
||||
c_scale_thread_buf[m0];
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// Compute offsets
|
||||
constexpr auto mfma_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
constexpr auto scale_buf_offset =
|
||||
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
|
||||
constexpr auto a_local_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) / NRepeat;
|
||||
constexpr auto b_local_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % (MRepeat * NRepeat)) % NRepeat;
|
||||
|
||||
constexpr auto local_buf_id =
|
||||
Number<mfma_reg_buf ^
|
||||
((m0 * NRepeat + n0 + 1) / (MRepeat * NRepeat))>{};
|
||||
|
||||
// Clear the current mfma output buffer
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
constexpr auto m0_offset = (a_local_buf_offset + HotloopLocalBufSwitch * mfma_reg_buf) % 2;
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[local_buf_id][Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0_offset, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[local_buf_id][Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(b_local_buf_offset, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType,
|
||||
xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
|
||||
});
|
||||
// Apply scaling
|
||||
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
// constexpr index_t c_offset =
|
||||
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
// c_thread_buf(Number<c_offset>{}) +=
|
||||
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
// .template AsType<AccDataType>()[Number<t>{}] *
|
||||
// type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
// });
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<scale_buf_offset>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
};
|
||||
|
||||
block_sync_lds();
|
||||
static_for<0, KRepeat, 1>{}([&](auto k) {
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
a_thread_copy_.Run(a_block_desc_m0_m1_m2_k,
|
||||
make_tuple(m0, I0, I0, Number<k * AMmaKStride>{}),
|
||||
a_block_buf.At(mfma_reg_buf),
|
||||
a_thread_desc_,
|
||||
make_tuple(m0, I0, k, I0),
|
||||
a_thread_buf);
|
||||
});
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
|
||||
make_tuple(n0, I0, I0, Number<k * BMmaKStride>{}),
|
||||
b_block_buf.At(mfma_reg_buf),
|
||||
b_thread_desc_,
|
||||
make_tuple(n0, I0, k, I0),
|
||||
b_thread_buf);
|
||||
});
|
||||
});
|
||||
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
c_scale_thread_buf(m0) = a_scale_thread_buf[mfma_reg_buf][m0] *
|
||||
b_scale_thread_buf[mfma_reg_buf][I0];
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
LoopFunc(I0, I1);
|
||||
LoopFunc(I1, I0);
|
||||
@@ -602,65 +645,45 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Full)
|
||||
{
|
||||
static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
vector_type<AccDataType, 2> c_scale_thread_vec;
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<0>{}) = c_scale_thread_buf[m0];
|
||||
c_scale_thread_vec.template AsType<AccDataType>()(Number<1>{}) = c_scale_thread_buf[m0];
|
||||
// if constexpr(TailNum == TailNumber::Full)
|
||||
// {
|
||||
// static_for<0, MRepeat, 1>{}([&](auto m0) {
|
||||
// static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
// .template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
// });
|
||||
// static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
// vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
// vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, NRepeat, 1>{}([&](auto n0) {
|
||||
constexpr auto mfma_buf_offset =
|
||||
((m0 * NRepeat + n0 + 1) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
constexpr auto scale_buf_offset =
|
||||
((m0 * NRepeat + n0) % 2) * xdlops_gemm.GetRegSizePerXdlops();
|
||||
// static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
// a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
// a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
// make_tuple(m0, I0, k0, ik))>{}];
|
||||
// b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
// b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
// make_tuple(n0, I0, k0, ik))>{}];
|
||||
// });
|
||||
|
||||
// Clear buffer for new MFMA computation
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{})
|
||||
.template AsType<AccDataType>()(Number<t>{}) = 0;
|
||||
});
|
||||
// using mfma_input_type =
|
||||
// typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
vector_type<ComputeDataType, KPack> a_thread_vec;
|
||||
vector_type<ComputeDataType, KPack> b_thread_vec;
|
||||
|
||||
static_for<0, KPack, 1>{}([&](auto ik) {
|
||||
a_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
a_thread_buf[Number<a_thread_desc_.CalculateOffset(
|
||||
make_tuple(m0, I0, k0, ik))>{}];
|
||||
b_thread_vec.template AsType<ComputeDataType>()(ik) =
|
||||
b_thread_buf[Number<b_thread_desc_.CalculateOffset(
|
||||
make_tuple(n0, I0, k0, ik))>{}];
|
||||
});
|
||||
|
||||
using mfma_input_type =
|
||||
typename vector_type<ComputeDataType, xdlops_gemm.K1PerXdlops>::type;
|
||||
|
||||
xdlops_gemm.template Run<>(
|
||||
a_thread_vec.template AsType<mfma_input_type>(),
|
||||
b_thread_vec.template AsType<mfma_input_type>(),
|
||||
c_thread_buf_per_scale.GetVectorTypeReference(Number<mfma_buf_offset>{}));
|
||||
});
|
||||
|
||||
// Apply scaling with packed FMA and accumulate to main buffer
|
||||
constexpr index_t c_offset =
|
||||
c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
|
||||
|
||||
static_for<0, xdlops_gemm.GetRegSizePerXdlops() / 2, 1>{}([&](auto t) {
|
||||
using pk_fma_type = typename vector_type<AccDataType, 2>::type;
|
||||
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()(t) = __builtin_elementwise_fma(
|
||||
c_thread_buf_per_scale
|
||||
.GetVectorTypeReference(Number<scale_buf_offset>{})
|
||||
.template AsType<pk_fma_type>()[t],
|
||||
c_scale_thread_vec.template AsType<pk_fma_type>()[Number<0>{}],
|
||||
c_thread_buf.GetVectorTypeReference(Number<c_offset>{})
|
||||
.template AsType<pk_fma_type>()[t]);
|
||||
});
|
||||
});
|
||||
});
|
||||
// xdlops_gemm.template Run<>(
|
||||
// a_thread_vec.template AsType<mfma_input_type>(),
|
||||
// b_thread_vec.template AsType<mfma_input_type>(),
|
||||
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}));
|
||||
// });
|
||||
// static_for<0, xdlops_gemm.GetRegSizePerXdlops(), 1>{}([&](auto t) {
|
||||
// constexpr index_t c_offset =
|
||||
// c_thread_desc_.CalculateOffset(make_tuple(m0, n0, t));
|
||||
// c_thread_buf(Number<c_offset>{}) +=
|
||||
// c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{})
|
||||
// .template AsType<AccDataType>()[Number<t>{}] *
|
||||
// type_convert<AccDataType>(c_scale_thread_buf[m0]);
|
||||
// });
|
||||
// });
|
||||
// });
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -973,8 +973,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
constexpr auto c_block_size =
|
||||
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize();
|
||||
|
||||
return math::max((a_block_space_size_aligned * sizeof(LDSTypeA) +
|
||||
b_block_space_size_aligned * sizeof(LDSTypeB)),
|
||||
return math::max((2 * a_block_space_size_aligned * sizeof(LDSTypeA) +
|
||||
2 * b_block_space_size_aligned * sizeof(LDSTypeB)),
|
||||
c_block_size * sizeof(CShuffleDataType));
|
||||
}
|
||||
|
||||
@@ -1327,15 +1327,26 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
// LDS allocation for A and B: be careful of alignment
|
||||
constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
|
||||
a_block_desc_ak0_m_ak1.GetElementSpaceSize(), max_lds_align);
|
||||
constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize(), max_lds_align);
|
||||
|
||||
// Cast after lds
|
||||
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
auto a_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
auto a_block_buf_2 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeA*>(p_shared) + a_block_space_size_aligned, a_block_desc_ak0_m_ak1.GetElementSpaceSize());
|
||||
auto a_block_bufs = make_tuple(a_block_buf_1, a_block_buf_2);
|
||||
|
||||
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
auto b_block_buf_1 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeB*>(p_shared) +
|
||||
a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
|
||||
2 * a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB),
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_buf_2 = make_dynamic_buffer<AddressSpaceEnum::Lds>(
|
||||
static_cast<LDSTypeB*>(p_shared) +
|
||||
2 * a_block_space_size_aligned * sizeof(LDSTypeA) / sizeof(LDSTypeB) +
|
||||
b_block_space_size_aligned,
|
||||
b_block_desc_bk0_n_bk1.GetElementSpaceSize());
|
||||
auto b_block_bufs = make_tuple(b_block_buf_1, b_block_buf_2);
|
||||
|
||||
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 0);
|
||||
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock / BK1Number, 0, 0);
|
||||
@@ -1410,13 +1421,13 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3
|
||||
a_block_desc_ak0_m_ak1,
|
||||
a_blockwise_copy,
|
||||
a_grid_buf,
|
||||
a_block_buf,
|
||||
a_block_bufs,
|
||||
a_block_slice_copy_step,
|
||||
b_grid_desc_bk0_n_bk1,
|
||||
b_block_desc_bk0_n_bk1,
|
||||
b_blockwise_copy,
|
||||
b_grid_buf,
|
||||
b_block_buf,
|
||||
b_block_bufs,
|
||||
b_block_slice_copy_step,
|
||||
|
||||
c_scale_thread_desc,
|
||||
|
||||
Reference in New Issue
Block a user