diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 92b98a2f11..4a8c2beeb7 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); ignore = a_element_func; diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index cdce1b1f31..2cafb715a2 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp index 4229db5250..234ff8821a 100644 --- a/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/03_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl // Q/K/V DRAM and DRAM window const auto q_dram = make_naive_tensor_view( - q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); const auto k_dram = make_naive_tensor_view( k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); @@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); auto q_dram_window = make_tile_window( - q_dram, make_tuple(number{}, number{}), {iM0, 0}); + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), {0, 0}); @@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl {iN1, 0}, MakeVDramTileDistribution()); - // Q in Register - auto q_reg_tensor = make_static_distributed_tensor( - BlockGemm0Policy::template MakeARegBlockDescriptor()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS @@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl // loop over Column of S (J loop) index_t iN0 = 0; - // Cold Q_Reg_Cache - s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); do { - // Hot Q_Reg_Cache - if(iN0 > 0) - { - s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); - } + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + // S{j} const auto s = tile_elementwise_in(type_convert, s_acc); diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp index 92b98a2f11..4a8c2beeb7 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds.hpp @@ -58,8 +58,7 @@ struct BlockGemmPipelineAGmemBGmemCReg< "wrong!"); static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}], + kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!"); ignore = a_element_func; diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp index cdce1b1f31..2cafb715a2 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/block_gemm_pipeline_agmem_bgmem_creg_v2_askiplds_policy.hpp @@ -91,6 +91,12 @@ struct BlockGemmPipelineAGmemBGmemCRegSkipALdsPersistentQRegCachePolicy return a_block_dstr; } + + template + __host__ __device__ static constexpr auto MakeADramTileDistribution() + { + return MakeARegBlockDescriptor(); + } }; } // namespace ck_tile diff --git a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp index 4229db5250..234ff8821a 100644 --- a/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp +++ b/example/ck_tile/99_toy_example/04_codegen_flash_attention_fwd/flash_attention_fwd_impl.hpp @@ -140,7 +140,7 @@ struct FlashAttentionFwdImpl // Q/K/V DRAM and DRAM window const auto q_dram = make_naive_tensor_view( - q_ptr, make_tuple(M0, K0), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); + q_ptr, make_tuple(M0, kHeadDim), make_tuple(StrideQ, 1), number<32>{}, number<1>{}); const auto k_dram = make_naive_tensor_view( k_ptr, make_tuple(N0, K0), make_tuple(StrideK, 1), number<32>{}, number<1>{}); @@ -149,7 +149,10 @@ struct FlashAttentionFwdImpl v_ptr, make_tuple(N1, N0), make_tuple(StrideV, 1), number<32>{}, number<1>{}); auto q_dram_window = make_tile_window( - q_dram, make_tuple(number{}, number{}), {iM0, 0}); + q_dram, + make_tuple(number{}, number{}), + {iM0, 0}, + BlockGemm0Policy::template MakeADramTileDistribution()); auto k_dram_window = make_tile_window( k_dram, make_tuple(number{}, number{}), {0, 0}); @@ -160,9 +163,8 @@ struct FlashAttentionFwdImpl {iN1, 0}, MakeVDramTileDistribution()); - // Q in Register - auto q_reg_tensor = make_static_distributed_tensor( - BlockGemm0Policy::template MakeARegBlockDescriptor()); + // Q in register + auto q_reg_tensor = load_tile(q_dram_window); // V LDS and LDS window // V LDS occupies the same LDS allocation Q/K LDS @@ -212,15 +214,10 @@ struct FlashAttentionFwdImpl // loop over Column of S (J loop) index_t iN0 = 0; - // Cold Q_Reg_Cache - s_acc = gemm0_pipeline(q_dram_window, k_dram_window, q_reg_tensor, smem_ptr); do { - // Hot Q_Reg_Cache - if(iN0 > 0) - { - s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); - } + s_acc = gemm0_pipeline(k_dram_window, q_reg_tensor, smem_ptr); + // S{j} const auto s = tile_elementwise_in(type_convert, s_acc);