Refactor FlashAttnArgs usage for codegen

This commit is contained in:
Clement Lin
2025-04-11 13:46:25 +08:00
committed by Philip Maybank
parent 0cc5130818
commit e614acfdd8
2 changed files with 124 additions and 44 deletions

View File

@@ -108,43 +108,9 @@ int main(int argc, char* argv[])
k_buf.ToDevice(k_host.mData.data());
v_buf.ToDevice(v_host.mData.data());
constexpr ck_tile::index_t kM0PerBlock = 128;
constexpr ck_tile::index_t kN0PerBlock = 128;
constexpr ck_tile::index_t kK0PerBlock = 32;
constexpr ck_tile::index_t kN1PerBlock = 128;
constexpr ck_tile::index_t kK1PerBlock = 32;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kHeadDim = 128;
ck_tile::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock);
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
ck_tile::FlashAttentionFwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{},
kGridSize,
kBlockSize,
0,
// Construct the FlashAttnArgs object with your arguments
ck_tile::FlashAttnArgs<QDataType, KDataType, VDataType, ODataType> flash_attention_args {
static_cast<QDataType*>(q_buf.GetDeviceBuffer()),
static_cast<KDataType*>(k_buf.GetDeviceBuffer()),
static_cast<VDataType*>(v_buf.GetDeviceBuffer()),
@@ -154,14 +120,25 @@ int main(int argc, char* argv[])
K0,
N1,
Batch,
K0, // StrideQ
K0, // StrideK
N0, // StrideV
N1, // StrideO
M0 * K0, // BatchStrideQ
N0 * K0, // BatchStrideK
N1 * N0, // BatchStrideV
M0 * N1)); // BatchStrideO
K0, // strideQ
K0, // strideK
N0, // strideV
N1, // strideO
M0 * K0, // batchStrideQ
N0 * K0, // batchStrideK
N1 * N0, // batchStrideV
M0 * N1 // batchStrideO
};
float ave_time = ck_tile::flash_attention_fwd<QDataType,
QDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType>
(flash_attention_args, ck_tile::stream_config{nullptr, true});
// reference
auto pass = true;

View File

@@ -15,6 +15,42 @@
namespace ck_tile {
template <typename QDataType,
typename KDataType,
typename VDataType,
typename ODataType>
struct FlashAttnArgs
{
// Pointers to device buffers for Q, K, V, O
QDataType* q_ptr;
KDataType* k_ptr;
VDataType* v_ptr;
ODataType* o_ptr;
// Problem sizes
index_t M0;
index_t N0;
index_t K0;
index_t N1;
index_t Batch;
// Strides within a batch
index_t strideQ;
index_t strideK;
index_t strideV;
index_t strideO;
// Batch strides
index_t batchStrideQ;
index_t batchStrideK;
index_t batchStrideV;
index_t batchStrideO;
};
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// P[M0, N0] = Softmax(S[M0, N0])
// O[M0, N1] = P[M0, N0] * V[N1, N0]
@@ -106,4 +142,71 @@ struct FlashAttentionFwd
}
};
// TODO: change to only dec
template <typename QDataType,
typename KDataType,
typename VDataType,
typename SaccDataType,
typename SMPLComputeDataType,
typename PDataType,
typename OaccDataType,
typename ODataType>
float flash_attention_fwd(const FlashAttnArgs<QDataType, KDataType, VDataType, ODataType>& a,
const ck_tile::stream_config& stream_config) {
constexpr ck_tile::index_t kM0PerBlock = 128;
constexpr ck_tile::index_t kN0PerBlock = 128;
constexpr ck_tile::index_t kK0PerBlock = 32;
constexpr ck_tile::index_t kN1PerBlock = 128;
constexpr ck_tile::index_t kK1PerBlock = 32;
constexpr ck_tile::index_t kBlockSize = 256;
constexpr ck_tile::index_t kHeadDim = 128;
ck_tile::index_t kGridSize = a.Batch * (a.M0 / kM0PerBlock) * (a.N1 / kN1PerBlock);
std::cout << "grid size " << kGridSize << std::endl;
constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD
constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize;
constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock;
return ck_tile::launch_kernel(stream_config,
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
ck_tile::FlashAttentionFwd<QDataType,
KDataType,
VDataType,
SaccDataType,
SMPLComputeDataType,
PDataType,
OaccDataType,
ODataType,
kBlockSize,
kHeadDim,
kM0PerBlock,
kN0PerBlock,
kK0PerBlock,
kN1PerBlock,
kK1PerBlock>{},
kGridSize,
kBlockSize,
0,
a.q_ptr,
a.k_ptr,
a.v_ptr,
a.o_ptr,
a.M0,
a.N0,
a.K0,
a.N1,
a.Batch,
a.strideQ, // StrideQ
a.strideK, // StrideK
a.strideV, // StrideV
a.strideO, // StrideO
a.batchStrideQ, // BatchStrideQ
a.batchStrideK, // BatchStrideK
a.batchStrideV, // BatchStrideV
a.batchStrideO)); // BatchStrideO
}
} // namespace ck_tile