mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Fused attention (#345)
* initial stub for gemm_gemm_xdl_cshuffle * set up example code * compiles * prevent integer overflow * harmonize interface between ref_gemm and ref_batched_gemm * batched_gemm_gemm * fix example * host tensor gen: diagonal pattern in lowest two-dimensions only * make c descriptors containing only integral constants * clean up * add BlockwiseGemmXdlops_v2 while exploring an unified approach * implement proper interface * tidy up example * fix compilation warnings * coarsely controlled 2nd gemm padding * remove rocm-cmake's hard requirement for certain revision * clang-format * resolve merge conflict * fix compilation error on gfx10 * adds acc0 elementwise op to interface * attention host validation * add blockwsie softmax v1 * iteratively update softmax+gemm * transpose both gemm0 and gemm1 xdl output so as to avoid broadcasting softmax max/sum * add init method for easier debugging * do away with manual thread cluster calculation * generalize blockwise softmax interface * row-wise softmax sum & max * format * rename to DeviceBatchedGemmSoftmaxGemm * add gemm_softmax_gemm instances and tests * comment Co-authored-by: ltqin <letao.qin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -1145,9 +1145,22 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
src_desc, src_data_coord);
|
||||
|
||||
// copy data from src_buf into src_tmp_vector
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
|
||||
if constexpr(SrcBuffer::IsDynamicBuffer())
|
||||
{
|
||||
src_tmp_vector.template AsType<src_vector_t>()(Number<0>{}) =
|
||||
src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);
|
||||
}
|
||||
else if constexpr(SrcBuffer::IsStaticBuffer())
|
||||
{
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_ref_to_origin_disp_idx + data_to_origin_disp_idx +
|
||||
i * src_scalar_step_in_vector);
|
||||
|
||||
// apply type convert
|
||||
src_tmp_vector.template AsType<SrcData>()(i) = src_buf[Number<src_offset>{}];
|
||||
});
|
||||
}
|
||||
// copy data from src_tmp_vector to dst_tmp_vector (data cast data from SrcData to
|
||||
// DstData)
|
||||
vector_type_maker_t<DstData, SrcScalarPerVector> dst_tmp_vector;
|
||||
@@ -1184,4 +1197,101 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
SrcCoord src_ref_coord_;
|
||||
};
|
||||
|
||||
// Do NOT involve any tensor coordinates with StaticBuffer
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename ElementwiseOperation,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t DstVectorDim,
|
||||
index_t DstScalarPerVector,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_StaticToStatic
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr ThreadwiseTensorSliceTransfer_StaticToStatic(
|
||||
const ElementwiseOperation& element_op)
|
||||
: element_op_{element_op}
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc need to known at compile-time");
|
||||
|
||||
static_assert(SliceLengths::At(Number<DstVectorDim>{}) % DstScalarPerVector == 0,
|
||||
"wrong! Not divisible");
|
||||
}
|
||||
|
||||
template <typename SrcSliceOriginIdx,
|
||||
typename DstSliceOriginIdx,
|
||||
typename SrcBuffer,
|
||||
typename DstBuffer>
|
||||
__device__ void Run(const SrcDesc&,
|
||||
const SrcSliceOriginIdx&,
|
||||
const SrcBuffer& src_buf,
|
||||
const DstDesc&,
|
||||
const DstSliceOriginIdx&,
|
||||
DstBuffer& dst_buf)
|
||||
{
|
||||
static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc need to known at compile-time");
|
||||
|
||||
static_assert(is_known_at_compile_time<remove_cvref_t<SrcSliceOriginIdx>>::value &&
|
||||
is_known_at_compile_time<remove_cvref_t<DstSliceOriginIdx>>::value,
|
||||
"wrong! SliceOrigin need to known at compile-time");
|
||||
|
||||
static_assert(SrcBuffer::IsStaticBuffer() && DstBuffer::IsStaticBuffer(),
|
||||
"wrong! Buffer need to be StaticBuffer");
|
||||
|
||||
// SrcDesc and src_slice_origin_idx are known at compile-time
|
||||
constexpr auto src_desc = remove_cvref_t<SrcDesc>{};
|
||||
constexpr auto dst_desc = remove_cvref_t<DstDesc>{};
|
||||
constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});
|
||||
constexpr auto dst_slice_origin_idx = to_multi_index(DstSliceOriginIdx{});
|
||||
|
||||
// scalar per access on each dim
|
||||
constexpr auto dst_scalar_per_access = generate_sequence(
|
||||
detail::lambda_scalar_per_access<DstVectorDim, DstScalarPerVector>{}, Number<nDim>{});
|
||||
|
||||
constexpr auto dst_scalar_step_in_vector =
|
||||
generate_sequence(detail::lambda_scalar_step_in_vector<DstVectorDim>{}, Number<nDim>{});
|
||||
|
||||
using SpaceFillingCurve = SpaceFillingCurve<SliceLengths,
|
||||
DimAccessOrder,
|
||||
remove_cv_t<decltype(dst_scalar_per_access)>>;
|
||||
|
||||
static_assert(DstScalarPerVector == SpaceFillingCurve::ScalarPerVector,
|
||||
"wrong!DstScalarPerVector != SpaceFillingCurve::ScalarPerVector");
|
||||
|
||||
constexpr auto num_access = SpaceFillingCurve::GetNumOfAccess();
|
||||
|
||||
static_for<0, num_access, 1>{}([&](auto idx_1d) {
|
||||
constexpr auto idx_md = SpaceFillingCurve::GetIndex(idx_1d);
|
||||
|
||||
// copy data from src_buf into dst_vector
|
||||
static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
|
||||
constexpr index_t src_offset = src_desc.CalculateOffset(
|
||||
src_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
constexpr index_t dst_offset = dst_desc.CalculateOffset(
|
||||
dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
|
||||
|
||||
SrcData v;
|
||||
|
||||
// apply element-wise operation
|
||||
element_op_(v, src_buf[Number<src_offset>{}]);
|
||||
|
||||
// apply type convert
|
||||
dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
ElementwiseOperation element_op_;
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
Reference in New Issue
Block a user