mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Fused attention (#345)
* initial stub for gemm_gemm_xdl_cshuffle * set up example code * compiles * prevent integer overflow * harmonize interface between ref_gemm and ref_batched_gemm * batched_gemm_gemm * fix example * host tensor gen: diagonal pattern in lowest two-dimensions only * make c descriptors containing only integral constants * clean up * add BlockwiseGemmXdlops_v2 while exploring an unified approach * implement proper interface * tidy up example * fix compilation warnings * coarsely controlled 2nd gemm padding * remove rocm-cmake's hard requirement for certain revision * clang-format * resolve merge conflict * fix compilation error on gfx10 * adds acc0 elementwise op to interface * attention host validation * add blockwsie softmax v1 * iteratively update softmax+gemm * transpose both gemm0 and gemm1 xdl output so as to avoid broadcasting softmax max/sum * add init method for easier debugging * do away with manual thread cluster calculation * generalize blockwise softmax interface * row-wise softmax sum & max * format * rename to DeviceBatchedGemmSoftmaxGemm * add gemm_softmax_gemm instances and tests * comment Co-authored-by: ltqin <letao.qin@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com>
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/utility/sequence_helper.hpp"
|
||||
#include "ck/tensor_description/multi_index_transform.hpp"
|
||||
|
||||
namespace ck {
|
||||
@@ -159,6 +160,12 @@ struct TensorDescriptor
|
||||
return transforms_[Number<itran>{}].GetUpperLengths()[Number<idim_up>{}];
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetLengths() const
|
||||
{
|
||||
// FIXME: use Tuple of reference instead
|
||||
return generate_sequence_v2([&](auto I) { return GetLength(I); }, Number<ndim_visible_>{});
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetElementSize() const { return element_size_; }
|
||||
|
||||
__host__ __device__ constexpr auto GetElementSpaceSize() const { return element_space_size_; }
|
||||
|
||||
Reference in New Issue
Block a user