mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
Add bias scalar vectorload = 1 for gemm bias gemm (#791)
* first change bias load * add bias dim and scalervector parameter * make CDE0BlockTransferSrcVectorDim not work * changse toinstance * add limit for CDE0BlockTransferSrcScalarPerVector
This commit is contained in:
@@ -67,6 +67,8 @@ template <typename A0B0B1DataType, // FIXME: don't assume A0/B0/B1 have same dat
|
||||
index_t B0BlockTransferDstScalarPerVector_BK1,
|
||||
bool B0ThreadTransferSrcResetCoordinateAfterRun, // ignored
|
||||
index_t B0BlockLdsExtraN,
|
||||
index_t CDE0BlockTransferSrcVectorDim,
|
||||
index_t CDE0BlockTransferSrcScalarPerVector,
|
||||
typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename B1BlockTransferThreadClusterArrangeOrder,
|
||||
typename B1BlockTransferSrcAccessOrder,
|
||||
@@ -710,13 +712,13 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
constexpr auto d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5 =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(I1, // MBlockId
|
||||
I1, // NBlockID
|
||||
I1, // MRepeat
|
||||
I1, // NRepeat
|
||||
I1, // MWaveId
|
||||
I1, // NWaveId
|
||||
I1, // MPerXdl
|
||||
I1, // NGroupNum
|
||||
I1, // NInputNum
|
||||
m0, // MRepeat
|
||||
n0, // NRepeat
|
||||
m1, // MWaveId
|
||||
n1, // NWaveId
|
||||
m2, // MPerXdl
|
||||
n2, // NGroupNum
|
||||
n3, // NInputNum
|
||||
n4)); // registerNum
|
||||
|
||||
auto d0s_thread_buf = generate_tuple(
|
||||
@@ -732,8 +734,9 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
const auto wave_id = GetGemm0WaveIdx();
|
||||
const auto wave_m_n_id = GetGemm0WaveMNIdx(wave_id[I2]); // I2: 0~63
|
||||
|
||||
constexpr auto acc0_thread_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(Number<Gemm0MXdlPerWave>{}, Number<Gemm0NXdlPerWave>{}, n2, n4));
|
||||
static_assert(CDE0BlockTransferSrcScalarPerVector <= n4,
|
||||
"vector load must be not greater than n4");
|
||||
static_assert(n4 % CDE0BlockTransferSrcScalarPerVector == 0);
|
||||
|
||||
auto d0s_threadwise_copy = generate_tuple(
|
||||
[&](auto i) {
|
||||
@@ -742,10 +745,19 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
A0B0B1DataType,
|
||||
decltype(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i]),
|
||||
decltype(d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5),
|
||||
Sequence<I1, I1, I1, I1, I1, I1, I1, I1, I1, n4>,
|
||||
Sequence<I1, // MBlockId
|
||||
I1, // NBlockID
|
||||
m0, // MRepeat
|
||||
n0, // NRepeat
|
||||
m1, // MWaveId
|
||||
n1, // NWaveId
|
||||
m2, // MPerXdl
|
||||
n2, // NGroupNum
|
||||
n3, // NInputNum
|
||||
n4>,
|
||||
Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>,
|
||||
9,
|
||||
n4,
|
||||
9, // CDE0BlockTransferSrcVectorDim
|
||||
CDE0BlockTransferSrcScalarPerVector,
|
||||
1,
|
||||
false>(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(block_work_idx[I0], // MBlockId
|
||||
@@ -898,66 +910,42 @@ struct GridwiseBatchedGemmMultipleDGemmMultipleD_Xdl_CShuffle
|
||||
blockwise_gemm0,
|
||||
acc0_thread_buf,
|
||||
num_k_block_main_loop);
|
||||
// bias+gelu
|
||||
// multiple d
|
||||
if constexpr(NumD0Tensor)
|
||||
{
|
||||
static_for<0, Gemm0MXdlPerWave, 1>{}([&](auto mr) {
|
||||
static_for<0, Gemm0NXdlPerWave, 1>{}([&](auto nr) {
|
||||
static_for<0, n2, 1>{}([&](auto groupid) {
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).Run(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
d0s_grid_buf[i],
|
||||
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
d0s_thread_buf(i));
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).Run(d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
d0s_grid_buf[i],
|
||||
d0_thread_desc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5,
|
||||
make_tuple(I0, I0, I0, I0, I0, I0, I0, I0, I0, I0),
|
||||
d0s_thread_buf(i));
|
||||
});
|
||||
static_for<0, m0 * n0 * n2 * n4, 1>{}([&](auto i) {
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& { return d0s_thread_buf[iSrc][i]; },
|
||||
Number<NumD0Tensor>{});
|
||||
|
||||
static_for<0, n4, 1>{}([&](auto i) {
|
||||
constexpr index_t c_offset = acc0_thread_desc.CalculateOffset(
|
||||
make_tuple(mr, nr, groupid, i));
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto) -> auto& { return acc0_thread_buf(i); },
|
||||
Number<2>{});
|
||||
|
||||
// get reference to src data
|
||||
const auto src_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto iSrc) -> const auto& {
|
||||
return d0s_thread_buf[iSrc][i];
|
||||
},
|
||||
Number<NumD0Tensor>{});
|
||||
|
||||
// get reference to dst data
|
||||
auto dst_data_refs = generate_tie(
|
||||
// return type should be lvalue
|
||||
[&](auto) -> auto& {
|
||||
return acc0_thread_buf(Number<c_offset>{});
|
||||
},
|
||||
Number<2>{});
|
||||
|
||||
unpack2(cde0_element_op, dst_data_refs, src_data_refs);
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 0, 0, 0, 0, 0, 0, 1, 0, 0));
|
||||
});
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 0, 0, 1, 0, 0, 0, -n2.value, 0, 0));
|
||||
});
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 0, 1, -Gemm0NXdlPerWave, 0, 0, 0, 0, 0, 0));
|
||||
});
|
||||
unpack2(cde0_element_op, dst_data_refs, src_data_refs);
|
||||
});
|
||||
static_for<0, NumD0Tensor, 1>{}([&](auto i) {
|
||||
d0s_threadwise_copy(i).MoveSrcSliceWindow(
|
||||
d0s_griddesc_m0_n0_m1_n1_m2_n2_m3_n3_n4_n5[i],
|
||||
make_multi_index(0, 1, -Gemm0MXdlPerWave, 0, 0, 0, 0, 0, 0, 0));
|
||||
make_multi_index(0, 1, 0, 0, 0, 0, 0, 0, 0, 0));
|
||||
});
|
||||
}
|
||||
else
|
||||
{
|
||||
static_for<0, acc0_thread_buf.Size(), 1>{}(
|
||||
[&](auto i) { cde0_element_op(acc_thread_buf(i), acc0_thread_buf[i]); });
|
||||
}
|
||||
// gemm1
|
||||
{
|
||||
// TODO: explore using dynamic buffer for a1 thread buffer
|
||||
|
||||
Reference in New Issue
Block a user