diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index e6b1707851..129a4c5ed5 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1 // B block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MWarp = config.template get<1>(); @@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1 return b_block_dstr_encode; } - static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); template @@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1 // C += A * B template - __device__ void operator() (CBlockTensor& c_block_tensor, + __device__ void operator()(CBlockTensor& c_block_tensor, const ABlockTensorTmp& a_block_tensor_tmp, const BLdsTile& b_block_tensor_tmp) const { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 6eee7f0d1e..8994689841 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 3924a66daf..e3f3fd0cd6 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 4383ce76cb..928ca83f65 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}, + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile @@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); #if !defined(TOY_FA_FWD_OPT) @@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); block_sync_lds(); }); #else @@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg move_tile_window(b_copy_dram_window, {0, kKPerBlock}); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - bWarpTile); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); @@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); block_sync_lds(); } @@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); } #endif return c_block_tile; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 68afedda27..9b52143c92 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return BlockGemmARegBSmemCRegV1{}; } - template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() { @@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } - template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } - template __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() { @@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, make_tuple( - make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{})), make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } - template __host__ __device__ static constexpr auto MakeBDramTileDistribution() { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp index c797714421..4ce61ed20c 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,13 +29,13 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; if(argc == 3) { diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp index 7fc5a78806..4317ebee8d 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd.hpp @@ -12,7 +12,6 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "flash_attention_fwd_impl.hpp" - namespace ck_tile { CK_TILE_HOST_DEVICE static constexpr auto MakeBlock2TileMap(index_t M0, index_t N0) diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 689b557500..bffed23722 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -14,7 +14,6 @@ #include "block_gemm_areg_bsmem_creg_v1.hpp" #include "tile_gemm_shape.hpp" - namespace ck_tile { // S[M0, N0] = Q[M0, K0] * K[N0, K0] @@ -196,14 +195,16 @@ struct FlashAttentionFwdImpl // V LDS tile window for store auto v_copy_lds_window = make_tile_window(v_lds, - make_tuple(number{}, number{}), - {0, 0}, - v_dram_window.get_tile_distribution()); + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); // V LDS tile for block GEMM - auto v_lds_gemm_window = make_tile_window( - v_lds, make_tuple(number{}, number{}), {0, 0}, - make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); #else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); @@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl store_tile(v_lds_window, v); block_sync_lds(); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); block_sync_lds(); }); #else @@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl move_tile_window(v_dram_window, {0, kK1PerBlock}); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1.template HotLoopScheduler<8, 4>(); @@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl } // tail { - if constexpr (k1_loops > 1) + if constexpr(k1_loops > 1) { gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 2) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } #endif diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp index e6b1707851..129a4c5ed5 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1.hpp @@ -32,7 +32,8 @@ struct BlockGemmARegBSmemCRegV1 // B block tile distribution for load from lds CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode() { - constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + constexpr auto config = + Policy::template GetWarpGemmMWarpNWarp(); using WG = remove_cvref_t())>; constexpr index_t MWarp = config.template get<1>(); @@ -55,7 +56,8 @@ struct BlockGemmARegBSmemCRegV1 return b_block_dstr_encode; } - static constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; + static constexpr auto BLdsTileDistr = + decltype(make_static_tile_distribution(MakeBBlockDistributionEncode())){}; using BLdsTile = decltype(make_static_distributed_tensor(BLdsTileDistr)); template @@ -149,15 +151,14 @@ struct BlockGemmARegBSmemCRegV1 // C += A * B template - __device__ void operator() (CBlockTensor& c_block_tensor, + __device__ void operator()(CBlockTensor& c_block_tensor, const ABlockTensorTmp& a_block_tensor_tmp, const BLdsTile& b_block_tensor_tmp) const { - static_assert( - std::is_same_v> && - std::is_same_v> && - std::is_same_v>, - "wrong!"); + static_assert(std::is_same_v> && + std::is_same_v> && + std::is_same_v>, + "wrong!"); constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = CBlockTensor{}.get_lengths()[number<1>{}]; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp index 6eee7f0d1e..8994689841 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_default_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1DefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp index 3924a66daf..e3f3fd0cd6 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_areg_bsmem_creg_v1_iteratek_policy.hpp @@ -13,15 +13,15 @@ struct BlockGemmARegBSmemCRegV1K8Policy template CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp() { - if constexpr (kM0 == 64) + if constexpr(kM0 == 64) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 4, 1); } - else if constexpr (kM0 == 32) + else if constexpr(kM0 == 32) { return make_tuple(WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}, 2, 1); } - else if constexpr (kM0 == 128) + else if constexpr(kM0 == 128) { return make_tuple(WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution{}, 4, 1); } diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 4383ce76cb..928ca83f65 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -261,7 +261,9 @@ struct BlockGemmPipelineAGmemBGmemCReg // B LDS tile for block GEMM auto b_lds_gemm_window = make_tile_window( - b_lds_block, make_tuple(number{}, number{}), {0, 0}, + b_lds_block, + make_tuple(number{}, number{}), + {0, 0}, make_static_tile_distribution(block_gemm.MakeBBlockDistributionEncode())); // Acc register tile @@ -269,7 +271,6 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, 0>{}, sequence{}), b_lds_gemm_window)){}; - tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); #if !defined(TOY_FA_FWD_OPT) @@ -279,10 +280,10 @@ struct BlockGemmPipelineAGmemBGmemCReg store_tile(b_copy_lds_window, b_block_tile); block_sync_lds(); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - b_copy_lds_window); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + b_copy_lds_window); block_sync_lds(); }); #else @@ -322,10 +323,10 @@ struct BlockGemmPipelineAGmemBGmemCReg move_tile_window(b_copy_dram_window, {0, kKPerBlock}); block_gemm(c_block_tile, - get_slice_tile(a_copy_reg_tensor, - sequence<0, i_k0 * kKPerBlock>{}, - sequence{}), - bWarpTile); + get_slice_tile(a_copy_reg_tensor, + sequence<0, i_k0 * kKPerBlock>{}, + sequence{}), + bWarpTile); block_sync_lds(); @@ -344,7 +345,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 2) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); block_sync_lds(); } @@ -358,7 +359,7 @@ struct BlockGemmPipelineAGmemBGmemCReg get_slice_tile(a_copy_reg_tensor, sequence<0, (k_loops - 1) * kKPerBlock>{}, sequence{}), - bWarpTile); + bWarpTile); } #endif return c_block_tile; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index 68afedda27..9b52143c92 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -21,7 +21,6 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return BlockGemmARegBSmemCRegV1{}; } - template __host__ __device__ static constexpr auto MakeARegBlockDescriptor() { @@ -60,14 +59,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } - template __host__ __device__ static constexpr auto MakeADramTileDistribution() { return MakeARegBlockDescriptor(); } - template __host__ __device__ static constexpr auto MakeBLdsBlockDescriptor() { @@ -99,24 +96,24 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( b_lds_block_desc_permuted, - make_tuple( - make_unmerge_transform(make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); constexpr auto b_lds_block_desc = transform_tensor_descriptor( b_lds_block_desc_xk0_mnldslayer_mn_xk1, make_tuple( - make_merge_transform(make_tuple(number{}, number{})), + make_merge_transform( + make_tuple(number{}, number{})), make_merge_transform(make_tuple(number{}, number{}))), make_tuple(sequence<1, 0>{}, sequence<2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); return b_lds_block_desc; } - template __host__ __device__ static constexpr auto MakeBDramTileDistribution() { diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp index 48680a218a..3750ede188 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.cpp @@ -29,13 +29,13 @@ int main(int argc, char* argv[]) using OaccDataType = float; using ODataType = ck_tile::half_t; - ck_tile::index_t Batch = 64; // Batch Number * Head Number - ck_tile::index_t M0 = 4096; // SequenceLengthQ - ck_tile::index_t N0 = 4096; // SequencelengthK - ck_tile::index_t K0 = 128; // HeadDim - ck_tile::index_t N1 = 128; // HeadDim - ck_tile::index_t verification = 0; - ck_tile::index_t init_method = 1; + ck_tile::index_t Batch = 64; // Batch Number * Head Number + ck_tile::index_t M0 = 4096; // SequenceLengthQ + ck_tile::index_t N0 = 4096; // SequencelengthK + ck_tile::index_t K0 = 128; // HeadDim + ck_tile::index_t N1 = 128; // HeadDim + ck_tile::index_t verification = 0; + ck_tile::index_t init_method = 1; if(argc == 3) { diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp index dcb901c0a2..38c56a27e8 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd.hpp @@ -155,7 +155,6 @@ struct FlashAttentionFwd const index_t iN1 = __builtin_amdgcn_readfirstlane(id_tile.template get(number<1>{}) % num_tile_n1 * kN1PerBlock); - const auto kernel_impl = FlashAttentionFwdImpl{}, number{}), - {0, 0}, - v_dram_window.get_tile_distribution()); + make_tuple(number{}, number{}), + {0, 0}, + v_dram_window.get_tile_distribution()); // V LDS tile for block GEMM - auto v_lds_gemm_window = make_tile_window( - v_lds, make_tuple(number{}, number{}), {0, 0}, - make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); + auto v_lds_gemm_window = + make_tile_window(v_lds, + make_tuple(number{}, number{}), + {0, 0}, + make_static_tile_distribution(gemm1.MakeBBlockDistributionEncode())); #else auto v_lds_window = make_tile_window( v_lds, make_tuple(number{}, number{}), {0, 0}); @@ -321,10 +322,10 @@ struct FlashAttentionFwdImpl store_tile(v_lds_window, v); block_sync_lds(); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - v_lds_window); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + v_lds_window); block_sync_lds(); }); #else @@ -361,10 +362,10 @@ struct FlashAttentionFwdImpl move_tile_window(v_dram_window, {0, kK1PerBlock}); gemm1(o_acc, - get_slice_tile(p, - sequence<0, i_k1 * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, i_k1 * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1.template HotLoopScheduler<8, 4>(); @@ -373,23 +374,23 @@ struct FlashAttentionFwdImpl } // tail { - if constexpr (k1_loops > 1) + if constexpr(k1_loops > 1) { gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 2) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 2) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } store_tile(v_copy_lds_window, v_prefetch); block_sync_lds(); vWarpTile = load_tile(v_lds_gemm_window); gemm1(o_acc, - get_slice_tile(p, - sequence<0, (k1_loops - 1) * kK1PerBlock>{}, - sequence{}), - vWarpTile); + get_slice_tile(p, + sequence<0, (k1_loops - 1) * kK1PerBlock>{}, + sequence{}), + vWarpTile); block_sync_lds(); } #endif diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py index 00bc91cadc..10def9a5dd 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/generate.py @@ -1,33 +1,18 @@ -# SPDX-License-Identifier: MIT -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +#SPDX - License - Identifier : MIT +#Copyright(c) 2025, Advanced Micro Devices, Inc.All rights reserved. -import argparse -from enum import IntEnum -from pathlib import Path -import sys -from typing import List, Optional, Any -import functools -import itertools -import copy -from dataclasses import dataclass +import argparse from enum import IntEnum from pathlib import Path import sys from typing import List, Optional, Any import functools import itertools import copy from dataclasses import dataclass -def get_if_str(size_, total, last_else=True): - if size_ == "head_dim_256_seq_4096": - return 'if' - else: - return 'else if' + def get_if_str(size_, total, last_else = True) : if size_ == "head_dim_256_seq_4096" : return 'if' else : return 'else if' -DATA_TYPE_MAP = {'fp32': 'float', - 'fp16': 'ck_tile::half_t', - 'bf16': 'ck_tile::bf16_t'} + DATA_TYPE_MAP = {'fp32' : 'float', 'fp16' : 'ck_tile::half_t', 'bf16' : 'ck_tile::bf16_t' } -def BOOL_MAP(b_) -> str: - return 'true' if b_ else 'false' + def BOOL_MAP(b_)->str: return 'true' if b_ else 'false' -class FlashAttentionFwdCodegen: - API_TRAITS_DEFINE = """ + class FlashAttentionFwdCodegen:API_TRAITS_DEFINE = "" + " -template -struct flash_attention_fwd_traits_ -{ - using SaccDataType = ck_tile::remove_cvref_t; - using SMPLComputeDataType = ck_tile::remove_cvref_t; - using PDataType = ck_tile::remove_cvref_t; - using OaccDataType = ck_tile::remove_cvref_t; + index_t kK1PerBlock_ = 64> struct flash_attention_fwd_traits_{using SaccDataType = ck_tile::remove_cvref_t ; +using SMPLComputeDataType = ck_tile::remove_cvref_t; +using PDataType = ck_tile::remove_cvref_t; +using OaccDataType = ck_tile::remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kHeadDim = kHeadDim_; - static constexpr index_t kM0PerBlock = kM0PerBlock_; - static constexpr index_t kN0PerBlock = kN0PerBlock_; - static constexpr index_t kK0PerBlock = kK0PerBlock_; - static constexpr index_t kN1PerBlock = kN1PerBlock_; - static constexpr index_t kK1PerBlock = kK1PerBlock_; +static constexpr index_t kBlockSize = kBlockSize_; +static constexpr index_t kHeadDim = kHeadDim_; +static constexpr index_t kM0PerBlock = kM0PerBlock_; +static constexpr index_t kN0PerBlock = kN0PerBlock_; +static constexpr index_t kK0PerBlock = kK0PerBlock_; +static constexpr index_t kN1PerBlock = kN1PerBlock_; +static constexpr index_t kK1PerBlock = kK1PerBlock_; - static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); - static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -}; +static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD +static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / get_warp_size(); +static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; +} +; template using traits_ = flash_attention_fwd_traits_; -""" + SMPLComputeDataType, + PDataType, + OaccDataType, + kBlockSize, + kHeadDim, + kM0PerBlock, + kN0PerBlock, + kK0PerBlock, + kN1PerBlock, + kK1PerBlock>; +"" + " - API_BASE = """ + API_BASE = "" + " // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include #include "flash_attention_fwd.hpp" -namespace ck_tile {{ + namespace ck_tile +{ + { -{F_traits_define} + { + F_traits_define + } -// Note: this internal API only declare, not define here, otherwise will block `make -j` -template -float flash_attention_fwd_(const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config); + // Note: this internal API only declare, not define here, otherwise will block `make -j` + template + float flash_attention_fwd_( + const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config); -template -float flash_attention_fwd(const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config) {{ - float r = -1; -{F_dispatch} - return r; -}} - -template float flash_attention_fwd( - const FlashAttnArgs&, - const ck_tile::stream_config&); - -}} -""" - - API_INNER_CASE = """ {F_if} {F_VEC_COND} - r = flash_attention_fwd_>(a, stream_config); -""" - - INSTANCE_BASE = """ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#include "flash_attention_fwd_api_common.hpp" - -namespace ck_tile { -// clang-format off -// -{F_instance_def} -// clang-format on + template + float flash_attention_fwd( + const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) + { + { + float r = -1; + { + F_dispatch + } + return r; + } + } + template float flash_attention_fwd(const FlashAttnArgs&, + const ck_tile::stream_config&); + } } -""" +"" + " - def __init__(self, working_path, kernel_filter): - self.working_path = working_path - self.kernel_filter = kernel_filter + API_INNER_CASE = "" + " {F_if} {F_VEC_COND} + r = flash_attention_fwd_>( + a, stream_config); +"" + " - @dataclass - class h_traits: - F_SaccDataType: str - F_SMPLComputeDataType: str - F_PDataType: str - F_OaccDataType: str - F_kBlockSize: int - F_kHeadDim: int - F_kM0PerBlock: int - F_kN0PerBlock: int - F_kK0PerBlock: int - F_kN1PerBlock: int - F_kK1PerBlock: int - - @property - def trait_name(self) -> str: - return (f"{DATA_TYPE_MAP[self.F_SaccDataType]}, " - f"{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " - f"{DATA_TYPE_MAP[self.F_PDataType]}, " - f"{DATA_TYPE_MAP[self.F_OaccDataType]}, " - f"{self.F_kBlockSize}, {self.F_kHeadDim}, " - f"{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " - f"{self.F_kN1PerBlock}, {self.F_kK1PerBlock}") - - @property - def def_name(self) -> str: - return (f"template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " - f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " - f"traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " - f"{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, " - "const ck_tile::stream_config&);") - - @dataclass - class h_instance: - F_DataTypePair: str # "q,k,v,o" - F_SizeCategory: str # "small", "medium", "large" - instance_list: List[Any] # List[h_traits] - - INSTANCE_BASE = """ + INSTANCE_BASE = "" + " // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #include "flash_attention_fwd_api_common.hpp" -namespace ck_tile {{ -// clang-format off + namespace ck_tile +{ + // clang-format off // {F_instance_def} -// clang-format on -}} + // clang-format on +} +"" + " + + def + __init__(self, working_path, kernel_filter) + : self.working_path = working_path self.kernel_filter = kernel_filter + + @dataclass class h_traits + : F_SaccDataType : str F_SMPLComputeDataType : str F_PDataType : str F_OaccDataType + : str F_kBlockSize : int F_kHeadDim : int F_kM0PerBlock : int F_kN0PerBlock : int F_kK0PerBlock + : int F_kN1PerBlock : int F_kK1PerBlock : int + + @property def trait_name(self) + ->str + : return (f "{DATA_TYPE_MAP[self.F_SaccDataType]}, " f + "{DATA_TYPE_MAP[self.F_SMPLComputeDataType]}, " f + "{DATA_TYPE_MAP[self.F_PDataType]}, " f "{DATA_TYPE_MAP[self.F_OaccDataType]}, " f + "{self.F_kBlockSize}, {self.F_kHeadDim}, " f + "{self.F_kM0PerBlock}, {self.F_kN0PerBlock}, {self.F_kK0PerBlock}, " f + "{self.F_kN1PerBlock}, {self.F_kK1PerBlock}") + + @property def def_name(self) + ->str + : return (f "template float flash_attention_fwd_<{DATA_TYPE_MAP['fp16']}, " f + "{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, " f + "traits_<{self.trait_name}>>(const FlashAttnArgs<{DATA_TYPE_MAP['fp16']}, " f + "{DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}, {DATA_TYPE_MAP['fp16']}>&, " + "const ck_tile::stream_config&);") + + @dataclass class h_instance : F_DataTypePair : str #"q,k,v,o" F_SizeCategory : str + #"small", + "medium", + "large" instance_list : List[Any] #List[h_traits] + + INSTANCE_BASE = "" + " +// SPDX-License-Identifier: MIT +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + +#include "flash_attention_fwd_api_common.hpp" + + namespace ck_tile +{ + { + // clang-format off +// +{F_instance_def} + // clang-format on + }} """ @property @@ -226,123 +229,133 @@ namespace ck_tile {{ #include "flash_attention_fwd.hpp" -namespace ck_tile {{ +namespace ck_tile +{ + { -template -struct flash_attention_fwd_traits_ -{{ - using SaccDataType = ck_tile::remove_cvref_t; - using SMPLComputeDataType = ck_tile::remove_cvref_t; - using PDataType = ck_tile::remove_cvref_t; - using OaccDataType = ck_tile::remove_cvref_t; + template + struct flash_attention_fwd_traits_ + { + { + using SaccDataType = ck_tile::remove_cvref_t; + using SMPLComputeDataType = ck_tile::remove_cvref_t; + using PDataType = ck_tile::remove_cvref_t; + using OaccDataType = ck_tile::remove_cvref_t; - static constexpr index_t kBlockSize = kBlockSize_; - static constexpr index_t kHeadDim = kHeadDim_; - static constexpr index_t kM0PerBlock = kM0PerBlock_; - static constexpr index_t kN0PerBlock = kN0PerBlock_; - static constexpr index_t kK0PerBlock = kK0PerBlock_; - static constexpr index_t kN1PerBlock = kN1PerBlock_; - static constexpr index_t kK1PerBlock = kK1PerBlock_; + static constexpr index_t kBlockSize = kBlockSize_; + static constexpr index_t kHeadDim = kHeadDim_; + static constexpr index_t kM0PerBlock = kM0PerBlock_; + static constexpr index_t kN0PerBlock = kN0PerBlock_; + static constexpr index_t kK0PerBlock = kK0PerBlock_; + static constexpr index_t kN1PerBlock = kN1PerBlock_; + static constexpr index_t kK1PerBlock = kK1PerBlock_; - static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD - static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; - static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; -}}; + static constexpr ck_tile::index_t kWarpPerCu = 8; // 2 warps per SIMD + static constexpr ck_tile::index_t kWarpPerBlock = kBlockSize / warpSize; + static constexpr ck_tile::index_t kBlockPerCu = kWarpPerCu / kWarpPerBlock; + } + }; + template + using traits_ = flash_attention_fwd_traits_; -template -using traits_ = flash_attention_fwd_traits_; + template + float flash_attention_fwd_( + const FlashAttnArgs& a, + const ck_tile::stream_config& stream_config) + { + { + using SaccDataType = typename Traits_::SaccDataType; + using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; + using PDataType = typename Traits_::PDataType; + using OaccDataType = typename Traits_::OaccDataType; + index_t kGridSize = + a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); -template -float flash_attention_fwd_(const FlashAttnArgs& a, - const ck_tile::stream_config& stream_config) {{ - using SaccDataType = typename Traits_::SaccDataType; - using SMPLComputeDataType = typename Traits_::SMPLComputeDataType; - using PDataType = typename Traits_::PDataType; - using OaccDataType = typename Traits_::OaccDataType; + if(stream_config.log_level_ > 0) + std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," + << Traits_::kHeadDim << ">" << std::flush; - index_t kGridSize = a.Batch * (a.M0 / Traits_::kM0PerBlock) * (a.N1 / Traits_::kN1PerBlock); - - if(stream_config.log_level_ > 0) - std::cout << ", " << "FlashAttentionFwd<" << Traits_::kBlockSize << "," << Traits_::kHeadDim << ">" << std::flush; - - return ck_tile::launch_kernel(stream_config, - ck_tile::make_kernel( - ck_tile::FlashAttentionFwd{{}}, - kGridSize, - Traits_::kBlockSize, - 0, - a.q_ptr, - a.k_ptr, - a.v_ptr, - a.o_ptr, - a.M0, - a.N0, - a.K0, - a.N1, - a.Batch, - a.strideQ, // StrideQ - a.strideK, // StrideK - a.strideV, // StrideV - a.strideO, // StrideO - a.batchStrideQ, // BatchStrideQ - a.batchStrideK, // BatchStrideK - a.batchStrideV, // BatchStrideV - a.batchStrideO)); // BatchStrideO -}} -}} + return ck_tile::launch_kernel( + stream_config, + ck_tile::make_kernel( + ck_tile::FlashAttentionFwd{{}}, + kGridSize, + Traits_::kBlockSize, + 0, + a.q_ptr, + a.k_ptr, + a.v_ptr, + a.o_ptr, + a.M0, + a.N0, + a.K0, + a.N1, + a.Batch, + a.strideQ, // StrideQ + a.strideK, // StrideK + a.strideV, // StrideV + a.strideO, // StrideO + a.batchStrideQ, // BatchStrideQ + a.batchStrideK, // BatchStrideK + a.batchStrideV, // BatchStrideV + a.batchStrideO)); // BatchStrideO + } + } + } +} """ def content_api(self, args) -> str: - # Sort based on dtype +#Sort based on dtype t_dtype_dict = {} blobs = self.get_blobs(args) @@ -402,7 +415,7 @@ float flash_attention_fwd_(const FlashAttnArgs