mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
Refactor FlashAttnArgs usage for codegen
This commit is contained in:
committed by
Philip Maybank
parent
0cc5130818
commit
e614acfdd8
@@ -108,43 +108,9 @@ int main(int argc, char* argv[])
|
||||
k_buf.ToDevice(k_host.mData.data());
|
||||
v_buf.ToDevice(v_host.mData.data());
|
||||
|
||||
constexpr ck_tile::index_t kM0PerBlock = 128;
|
||||
constexpr ck_tile::index_t kN0PerBlock = 128;
|
||||
constexpr ck_tile::index_t kK0PerBlock = 32;
|
||||
constexpr ck_tile::index_t kN1PerBlock = 128;
|
||||
constexpr ck_tile::index_t kK1PerBlock = 32;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kHeadDim = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
ck_tile::FlashAttentionFwd<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
// Construct the FlashAttnArgs object with your arguments
|
||||
ck_tile::FlashAttnArgs<QDataType, KDataType, VDataType, ODataType> flash_attention_args {
|
||||
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
|
||||
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
|
||||
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
|
||||
@@ -154,14 +120,25 @@ int main(int argc, char* argv[])
|
||||
K0,
|
||||
N1,
|
||||
Batch,
|
||||
K0, // StrideQ
|
||||
K0, // StrideK
|
||||
N0, // StrideV
|
||||
N1, // StrideO
|
||||
M0 * K0, // BatchStrideQ
|
||||
N0 * K0, // BatchStrideK
|
||||
N1 * N0, // BatchStrideV
|
||||
M0 * N1)); // BatchStrideO
|
||||
K0, // strideQ
|
||||
K0, // strideK
|
||||
N0, // strideV
|
||||
N1, // strideO
|
||||
M0 * K0, // batchStrideQ
|
||||
N0 * K0, // batchStrideK
|
||||
N1 * N0, // batchStrideV
|
||||
M0 * N1 // batchStrideO
|
||||
};
|
||||
|
||||
// Run the forward pass through the host-side wrapper and keep its return value.
// The explicit template arguments must line up with
// FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>, so the second one is
// KDataType. It used to be a duplicated QDataType, which only compiled when the
// Q and K element types happened to be the same type.
float ave_time = ck_tile::flash_attention_fwd<QDataType,
                                              KDataType,
                                              VDataType,
                                              SaccDataType,
                                              SMPLComputeDataType,
                                              PDataType,
                                              OaccDataType,
                                              ODataType>(
    flash_attention_args, ck_tile::stream_config{nullptr, true});
|
||||
|
||||
// reference
|
||||
auto pass = true;
|
||||
|
||||
@@ -15,6 +15,42 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename ODataType>
|
||||
struct FlashAttnArgs
|
||||
{
|
||||
// Pointers to device buffers for Q, K, V, O
|
||||
QDataType* q_ptr;
|
||||
KDataType* k_ptr;
|
||||
VDataType* v_ptr;
|
||||
ODataType* o_ptr;
|
||||
|
||||
// Problem sizes
|
||||
index_t M0;
|
||||
index_t N0;
|
||||
index_t K0;
|
||||
index_t N1;
|
||||
index_t Batch;
|
||||
|
||||
// Strides within a batch
|
||||
index_t strideQ;
|
||||
index_t strideK;
|
||||
index_t strideV;
|
||||
index_t strideO;
|
||||
|
||||
// Batch strides
|
||||
index_t batchStrideQ;
|
||||
index_t batchStrideK;
|
||||
index_t batchStrideV;
|
||||
index_t batchStrideO;
|
||||
};
|
||||
|
||||
|
||||
|
||||
|
||||
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
|
||||
// P[M0, N0] = Softmax(S[M0, N0])
|
||||
// O[M0, N1] = P[M0, N0] * V[N1, N0]
|
||||
@@ -106,4 +142,71 @@ struct FlashAttentionFwd
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: change to only a declaration
|
||||
template <typename QDataType,
|
||||
typename KDataType,
|
||||
typename VDataType,
|
||||
typename SaccDataType,
|
||||
typename SMPLComputeDataType,
|
||||
typename PDataType,
|
||||
typename OaccDataType,
|
||||
typename ODataType>
|
||||
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
|
||||
const ck_tile::stream_config& stream_config) {
|
||||
constexpr ck_tile::index_t kM0PerBlock = 128;
|
||||
constexpr ck_tile::index_t kN0PerBlock = 128;
|
||||
constexpr ck_tile::index_t kK0PerBlock = 32;
|
||||
constexpr ck_tile::index_t kN1PerBlock = 128;
|
||||
constexpr ck_tile::index_t kK1PerBlock = 32;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 256;
|
||||
constexpr ck_tile::index_t kHeadDim = 128;
|
||||
|
||||
ck_tile::index_t kGridSize = a.Batch * (a.M0 / kM0PerBlock) * (a.N1 / kN1PerBlock);
|
||||
|
||||
std::cout << "grid size " << kGridSize << std::endl;
|
||||
|
||||
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
|
||||
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
|
||||
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
|
||||
|
||||
return ck_tile::launch_kernel(stream_config,
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
ck_tile::FlashAttentionFwd<QDataType,
|
||||
KDataType,
|
||||
VDataType,
|
||||
SaccDataType,
|
||||
SMPLComputeDataType,
|
||||
PDataType,
|
||||
OaccDataType,
|
||||
ODataType,
|
||||
kBlockSize,
|
||||
kHeadDim,
|
||||
kM0PerBlock,
|
||||
kN0PerBlock,
|
||||
kK0PerBlock,
|
||||
kN1PerBlock,
|
||||
kK1PerBlock>{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
a.q_ptr,
|
||||
a.k_ptr,
|
||||
a.v_ptr,
|
||||
a.o_ptr,
|
||||
a.M0,
|
||||
a.N0,
|
||||
a.K0,
|
||||
a.N1,
|
||||
a.Batch,
|
||||
a.strideQ, // StrideQ
|
||||
a.strideK, // StrideK
|
||||
a.strideV, // StrideV
|
||||
a.strideO, // StrideO
|
||||
a.batchStrideQ, // BatchStrideQ
|
||||
a.batchStrideK, // BatchStrideK
|
||||
a.batchStrideV, // BatchStrideV
|
||||
a.batchStrideO)); // BatchStrideO
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user