diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp index 8ce1d6c6c7..77bfbd83cc 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -108,43 +108,9 @@ int main(int argc, char* argv[]) k_buf.ToDevice(k_host.mData.data()); v_buf.ToDevice(v_host.mData.data()); - constexpr ck_tile::index_t kM0PerBlock = 128; - constexpr ck_tile::index_t kN0PerBlock = 128; - constexpr ck_tile::index_t kK0PerBlock = 32; - constexpr ck_tile::index_t kN1PerBlock = 128; - constexpr ck_tile::index_t kK1PerBlock = 32; - constexpr ck_tile::index_t kBlockSize = 256; - constexpr ck_tile::index_t kHeadDim = 128; - - ck_tile::index_t kGridSize = Batch * (M0 / kM0PerBlock) * (N1 / kN1PerBlock); - - std::cout << "grid size " << kGridSize << std::endl; - - constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; - constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; - - float ave_time = ck_tile::launch_kernel(ck_tile::stream_config{nullptr, true}, - ck_tile::make_kernel( - ck_tile::FlashAttentionFwd{}, - kGridSize, - kBlockSize, - 0, + // Construct the FlashAttnArgs object with your arguments + ck_tile::FlashAttnArgs flash_attention_args { static_cast(q_buf.GetDeviceBuffer()), static_cast(k_buf.GetDeviceBuffer()), static_cast(v_buf.GetDeviceBuffer()), @@ -154,14 +120,25 @@ int main(int argc, char* argv[]) K0, N1, Batch, - K0, // StrideQ - K0, // StrideK - N0, // StrideV - N1, // StrideO - M0 * K0, // BatchStrideQ - N0 * K0, // BatchStrideK - N1 * N0, // BatchStrideV - M0 * N1)); // BatchStrideO + K0, // strideQ + K0, // strideK + N0, // strideV + N1, // strideO + M0 * K0, // batchStrideQ + N0 * K0, // batchStrideK + N1 * N0, // batchStrideV + M0 * N1 // batchStrideO + }; + + float ave_time = ck_tile::flash_attention_fwd + (flash_attention_args, ck_tile::stream_config{nullptr, true}); // reference auto pass = true; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index caeeece8e9..01580b586d 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -15,6 +15,42 @@ namespace ck_tile { + +template +struct FlashAttnArgs +{ + // Pointers to device buffers for Q, K, V, O + QDataType* q_ptr; + KDataType* k_ptr; + VDataType* v_ptr; + ODataType* o_ptr; + + // Problem sizes + index_t M0; + index_t N0; + index_t K0; + index_t N1; + index_t Batch; + + // Strides within a batch + index_t strideQ; + index_t strideK; + index_t strideV; + index_t strideO; + + // Batch strides + index_t batchStrideQ; + index_t batchStrideK; + index_t batchStrideV; + index_t batchStrideO; +}; + + + + // S[M0, N0] = Q[M0, K0] * K[N0, K0] // P[M0, N0] = Softmax(S[M0, N0]) // O[M0, N1] = P[M0, N0] * V[N1, N0] @@ -106,4 +142,71 @@ struct FlashAttentionFwd } }; +// TODO: change to only dec +template +float flash_attention_fwd(const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) { + constexpr ck_tile::index_t kM0PerBlock = 128; + constexpr ck_tile::index_t kN0PerBlock = 128; + constexpr ck_tile::index_t kK0PerBlock = 32; + constexpr ck_tile::index_t kN1PerBlock = 128; + constexpr ck_tile::index_t kK1PerBlock = 32; + + constexpr ck_tile::index_t kBlockSize = 256; + constexpr ck_tile::index_t kHeadDim = 128; + + ck_tile::index_t kGridSize = a.Batch * (a.M0 / kM0PerBlock) * (a.N1 / kN1PerBlock); + + std::cout << "grid size " << kGridSize << std::endl; + + constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + + return ck_tile::launch_kernel(stream_config, + ck_tile::make_kernel( + ck_tile::FlashAttentionFwd{}, + kGridSize, + kBlockSize, + 0, + a.q_ptr, + a.k_ptr, + a.v_ptr, + a.o_ptr, + a.M0, + a.N0, + a.K0, + a.N1, + a.Batch, + a.strideQ, // StrideQ + a.strideK, // StrideK + a.strideV, // StrideV + a.strideO, // StrideO + a.batchStrideQ, // BatchStrideQ + a.batchStrideK, // BatchStrideK + a.batchStrideV, // BatchStrideV + a.batchStrideO)); // BatchStrideO +} + } // namespace ck_tile