mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Fused attention (#345)
* initial stub for gemm_gemm_xdl_cshuffle * set up example code * compiles * prevent integer overflow * harmonize interface between ref_gemm and ref_batched_gemm * batched_gemm_gemm * fix example * host tensor gen: diagonal pattern in lowest two dimensions only * make c descriptors containing only integral constants * clean up * add BlockwiseGemmXdlops_v2 while exploring a unified approach * implement proper interface * tidy up example * fix compilation warnings * coarsely controlled 2nd gemm padding * remove rocm-cmake's hard requirement for a certain revision * clang-format * resolve merge conflict * fix compilation error on gfx10 * adds acc0 elementwise op to interface * attention host validation * add blockwise softmax v1 * iteratively update softmax+gemm * transpose both gemm0 and gemm1 xdl output so as to avoid broadcasting softmax max/sum * add init method for easier debugging * do away with manual thread cluster calculation * generalize blockwise softmax interface * row-wise softmax sum & max * format * rename to DeviceBatchedGemmSoftmaxGemm * add gemm_softmax_gemm instances and tests * comment Co-authored-by: ltqin <letao.qin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -579,7 +579,11 @@ struct MfmaSelector
|
||||
static constexpr index_t GetK1PerXdlops() { return selected_mfma.k_per_blk; }
|
||||
};
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops, index_t KPack>
|
||||
template <typename base_type,
|
||||
index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t KPack,
|
||||
bool TransposeC = false>
|
||||
struct XdlopsGemm
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -612,6 +616,8 @@ struct XdlopsGemm
|
||||
static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk");
|
||||
}
|
||||
|
||||
// XDL output supporting C = A * B
|
||||
// M2_N2 -> M2_M3_M4_N2
|
||||
template <typename CDesc_M0_N0_M1_N1_M2_N2>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
@@ -627,10 +633,10 @@ struct XdlopsGemm
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(N1),
|
||||
make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk,
|
||||
mfma_instr.num_input_blks,
|
||||
mfma_instr.group_size)),
|
||||
make_pass_through_transform(mfma_instr.num_threads_per_blk)),
|
||||
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<mfma_instr.group_size>{})),
|
||||
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{})),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
@@ -645,6 +651,41 @@ struct XdlopsGemm
|
||||
Sequence<7>{}));
|
||||
}
|
||||
|
||||
// transposed XDL output supporting C' = B' * A'
//
// Builds an 8-dimensional C descriptor from the 6-dimensional descriptor
// passed in (dimensions named M0, N0, M1, N1, M2, N2 by the parameter type).
// Dimensions 0-4 are passed through and dimension 5 is split into three.
|
||||
// M2_N2 -> M2_N2_N3_N4
// i.e. M2 is kept as a single dimension, while N2 is unmerged into
// (num_groups_per_blk, num_input_blks, group_size) of the selected mfma_instr.
|
||||
template <typename CDesc_M0_N0_M1_N1_M2_N2>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeCDescriptor_M0_N0_M1_N1_M2_N2_N3_N4(const CDesc_M0_N0_M1_N1_M2_N2& c_desc_m0_n0_m1_n1_m2_n2)
|
||||
{
|
||||
// Lengths of the four leading dimensions are read from the incoming descriptor.
const auto M0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I0);
|
||||
const auto N0 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I1);
|
||||
const auto M1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I2);
|
||||
const auto N1 = c_desc_m0_n0_m1_n1_m2_n2.GetLength(I3);
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
c_desc_m0_n0_m1_n1_m2_n2,
|
||||
// One transform per dimension of the incoming descriptor.
make_tuple(make_pass_through_transform(M0),
|
||||
make_pass_through_transform(N0),
|
||||
make_pass_through_transform(M1),
|
||||
make_pass_through_transform(N1),
|
||||
// Dimension 4: passed through, with its length given as the compile-time
// constant num_threads_per_blk rather than read from the descriptor.
make_pass_through_transform(Number<mfma_instr.num_threads_per_blk>{}),
|
||||
// Dimension 5: split into three dimensions whose lengths are compile-time constants.
make_unmerge_transform(make_tuple(Number<mfma_instr.num_groups_per_blk>{},
|
||||
Number<mfma_instr.num_input_blks>{},
|
||||
Number<mfma_instr.group_size>{}))),
|
||||
// Dimension ids of the incoming descriptor consumed by each transform above.
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5>{}),
|
||||
// Dimension ids of the resulting descriptor produced by each transform;
// the unmerge yields the last three dimensions (5, 6, 7).
make_tuple(Sequence<0>{},
|
||||
Sequence<1>{},
|
||||
Sequence<2>{},
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<5, 6, 7>{}));
|
||||
}
|
||||
|
||||
template <typename CDesc_G_M0_N0_M1_N1_M2_N2>
|
||||
__host__ __device__ static constexpr auto MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
|
||||
const CDesc_G_M0_N0_M1_N1_M2_N2& c_desc_g_m0_n0_m1_n1_m2_n2)
|
||||
@@ -698,7 +739,16 @@ struct XdlopsGemm
|
||||
"base base_type must be double, float, half, bfloat16, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
if constexpr(!TransposeC)
|
||||
{
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
}
|
||||
else
|
||||
{
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(
|
||||
p_b_wave[k], p_a_wave[k], p_c_thread);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user