diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 3be95fdbae..c9e1e30c57 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -4,7 +4,6 @@ #pragma once #include "blockgemm_pipeline_agmem_bgmem_creg_policy_impl.hpp" -#include "../../../example/ck_tile/99_toy_example/02_gemm/block_gemm_pipeline_agmem_bgmem_creg.hpp" namespace ck_tile { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp index fba49f9de5..c797714421 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp @@ -36,25 +36,21 @@ int main(int argc, char* argv[]) ck_tile::index_t N1 = 128; // HeadDim ck_tile::index_t verification = 0; ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; - if(argc == 4) + if(argc == 3) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); + verification = std::stoi(argv[2]); } - - if(argc == 9) + else if(argc == 8) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); - Batch = std::stoi(argv[4]); - M0 = std::stoi(argv[5]); - N0 = std::stoi(argv[6]); - K0 = std::stoi(argv[7]); - N1 = std::stoi(argv[8]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index 830b2422b5..33d36954c0 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -38,6 +38,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -46,7 +48,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; @@ -180,6 +182,8 @@ struct BlockGemmARegBSmemCRegV1 std::is_same_v>, "wrong!"); + static_assert((BlockGemmShape::kM == BlockGemmShape::kN), "wrong!"); + constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; @@ -188,7 +192,7 @@ struct BlockGemmARegBSmemCRegV1 KPerBlock == BlockGemmShape::kK, "wrong!"); - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index fb1516eb52..6eee7f0d1e 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1DefaultPolicy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + if constexpr (kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); + } + else if constexpr (kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); + } + else if constexpr (kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 32dc09f95e..3924a66daf 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -10,10 +10,25 @@ namespace ck_tile { struct BlockGemmARegBSmemCRegV1K8Policy { - template + template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + if constexpr (kM0 == 64) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); + } + else if constexpr (kM0 == 32) + { + return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); + } + else if constexpr (kM0 == 128) + { + return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); + } + else + { + static_assert(false, "Unsupported configuration for warp execution."); + } } }; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 3be95fdbae..4dd0c9a1e0 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -62,11 +62,13 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto blockgemm = GetBlockGemm(); using BlockGemm = remove_cvref_t; + static_assert((Problem::BlockGemmShape::kM == Problem::BlockGemmShape::kN), "wrong!"); + constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kKPerBlock = AKDim; constexpr auto config = - BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); + BlockGemm::BlockGemmPolicy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp index 9b7c9b1c6c..48680a218a 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -36,25 +36,21 @@ int main(int argc, char* argv[]) ck_tile::index_t N1 = 128; // HeadDim ck_tile::index_t verification = 0; ck_tile::index_t init_method = 1; - [[maybe_unused]] ck_tile::index_t time_kernel = 0; - if(argc == 4) + if(argc == 3) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); + verification = std::stoi(argv[2]); } - - if(argc == 9) + else if(argc == 8) { init_method = std::stoi(argv[1]); - time_kernel = std::stoi(argv[2]); - verification = std::stoi(argv[3]); - Batch = std::stoi(argv[4]); - M0 = std::stoi(argv[5]); - N0 = std::stoi(argv[6]); - K0 = std::stoi(argv[7]); - N1 = std::stoi(argv[8]); + verification = std::stoi(argv[2]); + Batch = std::stoi(argv[3]); + M0 = std::stoi(argv[4]); + N0 = std::stoi(argv[5]); + K0 = std::stoi(argv[6]); + N1 = std::stoi(argv[7]); } std::array q_lengths{Batch, M0, K0}; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 040133f8e9..00bc91cadc 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -11,16 +11,8 @@ import itertools import copy from dataclasses import dataclass -# def get_if_str(idx, total, last_else=True): -# if idx == 0: -# return 'if' -# elif idx < total - 1: -# return 'else if' -# else: -# return 'else' if last_else else 'else if' - def get_if_str(size_, total, last_else=True): - if size_ == "small": + if size_ == "head_dim_256_seq_4096": return 'if' else: return 'else if' @@ -39,13 +31,13 @@ template + index_t kBlockSize_ = 256, + index_t kHeadDim_ = 128, + index_t kM0PerBlock_ = 128, + index_t kN0PerBlock_ = 128, + index_t kK0PerBlock_ = 64, + index_t kN1PerBlock_ = 128, + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_ { using SaccDataType = ck_tile::remove_cvref_t; @@ -62,7 +54,7 @@ struct flash_attention_fwd_traits_ static constexpr index_t kK1PerBlock = kK1PerBlock_; static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; }; @@ -70,13 +62,13 @@ template + ck_tile::index_t kBlockSize = 256, + ck_tile::index_t kHeadDim = 128, + ck_tile::index_t kM0PerBlock = 128, + ck_tile::index_t kN0PerBlock = 128, + ck_tile::index_t kK0PerBlock = 64, + ck_tile::index_t kN1PerBlock = 128, + ck_tile::index_t kK1PerBlock = 64> using traits_ = flash_attention_fwd_traits_; """ -# API_COMMON_HEADER = """ -# // SPDX-License-Identifier: MIT -# // Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. - -# #include -# #include "flash_attention_fwd.hpp" -# #include - -# #pragma once - -# using S = ck_tile::stream_config; -# using A = FlashAttnArgs; - -# {F_traits_define} - -# template -# float flash_attention_fwd_(const FlashAttnArgs& a, -# const ck_tile::stream_config& stream_config) {{ -# using SaccDataType = typename Traits_::SaccDataType; -# using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; -# using PDataType = typename Traits_::PDataType; -# using OaccDataType = typename Traits_::OaccDataType; - -# index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - -# if(stream_config.log_level_ > 0) -# std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; - -# return ck_tile::launch_kernel(stream_config, -# ck_tile::make_kernel( -# ck_tile::FlashAttentionFwd{{}}, -# kGridSize, -# Traits_::kBlockSize, -# 0, -# a.q_ptr, -# a.k_ptr, -# a.v_ptr, -# a.o_ptr, -# a.M0, -# a.N0, -# a.K0, -# a.N1, -# a.Batch, -# a.strideQ, // StrideQ -# a.strideK, // StrideK -# a.strideV, // StrideV -# a.strideO, // StrideO -# a.batchStrideQ, // BatchStrideQ -# a.batchStrideK, // BatchStrideK -# a.batchStrideV, // BatchStrideV -# a.batchStrideO)); // BatchStrideO -# }} -# """ - API_BASE = """ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. @@ -204,14 +124,6 @@ template float flash_attention_fwd && std::is_same_v && std::is_same_v && std::is_same_v) {{ -# {F_per_size_case} -# }} -# """ -# API_PER_SIZE_CASE = """ {F_if} {F_SIZE_COND} {{ -# {F_inner_dispatch} -# }} -# """ API_INNER_CASE = """ {F_if} {F_VEC_COND} r = flash_attention_fwd_>(a, stream_config); """ @@ -320,13 +232,13 @@ template + index_t kBlockSize_ = 256, + index_t kHeadDim_ = 128, + index_t kM0PerBlock_ = 128, + index_t kN0PerBlock_ = 128, + index_t kK0PerBlock_ = 64, + index_t kN1PerBlock_ = 128, + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_ {{ using SaccDataType = ck_tile::remove_cvref_t; @@ -456,36 +368,28 @@ float flash_attention_fwd_(const FlashAttnArgs