mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 22:22:27 +00:00
Generate output using Doxygen / Breathe (#598)
* Modify Doxygen config to pick up include directories recursively * Add DeviceMem struct to API Reference guide * Add classes that are used in Flash Attention kernel * Add a reference and config for generating bibliography Co-authored-by: Philip Maybank <Philip.Maybank@amd.com>
This commit is contained in:
@@ -622,11 +622,16 @@ constexpr auto BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector()
|
||||
}
|
||||
};
|
||||
|
||||
// Blockwise gemm supporting
|
||||
// 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4
|
||||
// 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS
|
||||
// source buffer
|
||||
// 3. configurable k index starting position and step size after each FMA/XDL instruction
|
||||
/**
|
||||
* @brief Blockwise gemm
|
||||
*
|
||||
* Supports
|
||||
* 1. regular XDL output M2_M3_M4_M2 and transposed XDL output M2_N2_N3_N4
|
||||
* 2. decoupled input tile descriptor and mma tile descriptor in order to support both vgpr and LDS
|
||||
* source buffer
|
||||
* 3. configurable k index starting position and step size after each FMA/XDL instruction
|
||||
*/
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
|
||||
@@ -12,6 +12,16 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* @brief Blockwise softmax
|
||||
*
|
||||
* @tparam BlockSize
|
||||
* @tparam AccDataType
|
||||
* @tparam ThreadMap_M_K
|
||||
* @tparam ThreadClusterDesc_M_K
|
||||
* @tparam ThreadSliceDesc_M_K
|
||||
* @tparam IgnoreNaN
|
||||
*/
|
||||
template <index_t BlockSize,
|
||||
typename AccDataType,
|
||||
typename ThreadMap_M_K, // thread_id to m_k
|
||||
|
||||
@@ -11,10 +11,15 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// this version does following things to avoid scratch memory issue
|
||||
// 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
// 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
// 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
/**
|
||||
* @brief Blockwise data transfer
|
||||
*
|
||||
* This version does following things to avoid scratch memory issue
|
||||
* 1. Use StaticallyIndexedArray instead of C array for thread buffer
|
||||
* 2. ThreadwiseTensorSliceTransfer_v3 does not keep reference to tensor descriptor
|
||||
* 3. ThreadwiseTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
|
||||
*
|
||||
*/
|
||||
template <typename ThreadGroup,
|
||||
typename SrcElementwiseOperation,
|
||||
typename DstElementwiseOperation,
|
||||
|
||||
@@ -18,6 +18,10 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
/**
|
||||
* @brief Gridwise gemm + softmax + gemm fusion
|
||||
*
|
||||
*/
|
||||
template <typename FloatAB,
|
||||
typename FloatGemmAcc,
|
||||
typename FloatCShuffle,
|
||||
|
||||
@@ -1201,7 +1201,12 @@ struct ThreadwiseTensorSliceTransfer_v4
|
||||
SrcCoord src_ref_coord_;
|
||||
};
|
||||
|
||||
// Do NOT involve any tensor coordinates with StaticBuffer
|
||||
/**
|
||||
* @brief Threadwise data transfer
|
||||
*
|
||||
* Do NOT involve any tensor coordinates with StaticBuffer
|
||||
*
|
||||
*/
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
|
||||
Reference in New Issue
Block a user