From 1edd7a331e5a2e980a5be76df741fa255cb23c00 Mon Sep 17 00:00:00 2001 From: aska-0096 Date: Fri, 18 Apr 2025 02:07:32 +0000 Subject: [PATCH] new layout sanity checked. not correct instruction generated --- .../ck_tile/01_fmha/example_bwd_fmha_bf16.cpp | 55 ++- include/ck_tile/core/tensor/slice_tile.hpp | 6 +- include/ck_tile/host/check_err.hpp | 12 +- .../ops/fmha/kernel/fmha_bwd_kernel.hpp | 8 +- ...a_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp | 268 ++++++++++--- ...block_fmha_bwd_pipeline_default_policy.hpp | 351 ++++++++++++++++-- include/ck_tile/ops/gemm/warp/warp_gemm.hpp | 10 + .../gemm/warp/warp_gemm_attribute_mfma.hpp | 208 +++++++---- .../ops/gemm/warp/warp_gemm_dispatcher.hpp | 11 +- 9 files changed, 770 insertions(+), 159 deletions(-) diff --git a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp index 2c1264eee3..46018b99b0 100644 --- a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp +++ b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp @@ -600,6 +600,54 @@ bool run(const ck_tile::ArgParser& arg_parser) ck_tile::FillTrigValue{}(bias_host); ck_tile::FillTrigValue{}(do_host); } + else if(init_method == 3) + { + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(do_host); + } + else if(init_method == 4) + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(do_host); + } + else if(init_method == 5) + { + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(do_host); + } + else if(init_method == 6) + { + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(do_host); + } + else if(init_method == 7) + { + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(do_host); + } + else if(init_method == 8) + { + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(q_host); + ck_tile::FillUniformDistribution{0.f, 1.f, seed}(k_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(v_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(bias_host); + ck_tile::FillUniformDistribution{1.f, 1.f, seed}(do_host); + } if(bias.type == bias_enum::alibi) { auto slopes = ck_tile::get_alibi_slopes(nhead); @@ -619,9 +667,12 @@ bool run(const ck_tile::ArgParser& arg_parser) } } - // for(int iM=0; iM<128; iM++){ // for(int iK=0; iK<16; iK++){ - // printf("%04x ", *(reinterpret_cast(&(q_host(0, 0, iK, iM))))); + // for(int iM=0; iM<128; iM++){ + // printf("%04x ", *(reinterpret_cast(&(q_host(0, 0, iK, iM))))); + // if(iM%16==15){ + // printf("|"); + // } // } // printf("\n"); // } diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp index d51b4c92fb..3b696d8cc8 100644 --- a/include/ck_tile/core/tensor/slice_tile.hpp +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -76,6 +76,7 @@ set_slice_tile(static_distributed_tensor slice_ends) { using DstDistribution = remove_cvref_t; + using SrcDistribution = remove_cvref_t; constexpr auto sliced_dstr_yidx_ylen = detail::slice_distribution_from_x(DstDistribution{}, slice_begins, slice_ends); @@ -84,9 +85,10 @@ set_slice_tile(static_distributed_tensor(); constexpr auto sliced_y_lengths = sliced_dstr_yidx_ylen.template at<2>(); - static_assert(std::is_same_v, "wrong!"); + static_assert(std::is_same_v, SrcDistribution>, "wrong!"); - dst_tile.set_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); + dst_tile.set_y_sliced_thread_data( + sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); } } // namespace ck_tile diff --git a/include/ck_tile/host/check_err.hpp b/include/ck_tile/host/check_err.hpp index 745c18d6dd..6174822f7e 100644 --- a/include/ck_tile/host/check_err.hpp +++ b/include/ck_tile/host/check_err.hpp @@ -186,7 +186,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 5000) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -246,7 +246,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 5000) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -305,7 +305,7 @@ check_err(const Range& out, { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 5000) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -360,7 +360,7 @@ std::enable_if_t<(std::is_same_v, ranges::range_val { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 5000) { std::cerr << msg << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; @@ -437,7 +437,7 @@ std::enable_if_t<(std::is_same_v, ranges::range_val { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 5000) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o_fp64 << " != " << r_fp64 << std::endl; @@ -495,7 +495,7 @@ std::enable_if_t<(std::is_same_v, ranges::range_val { max_err = err > max_err ? err : max_err; err_count++; - if(err_count < 5) + if(err_count < 5000) { std::cerr << msg << std::setw(12) << std::setprecision(7) << " out[" << i << "] != ref[" << i << "]: " << o << " != " << r << std::endl; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index 21467f1d85..697281b81c 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -1277,14 +1277,18 @@ struct FmhaBwdDQDKDVKernel make_tuple(number{}, number{}), {0, 0}); + constexpr auto kSeq0 = 64; + auto k_dram_window = make_tile_window( k_dram, - make_tuple(number{}, number{}), + // make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_n0, 0}); auto v_dram_window = make_tile_window( v_dram, - make_tuple(number{}, number{}), + // make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_n0, 0}); auto do_dram_window = make_tile_window( diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 509a5c4a25..4fe05de3aa 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -125,8 +125,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP "wrong!"); static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && - kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + // kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && + // kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && @@ -138,7 +138,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP "wrong!"); // if (threadIdx.x == 0){ - // HotLoopScheduler::print(); + // // HotLoopScheduler::print(); // } // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); @@ -175,19 +175,21 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP return make_tuple(dk_acc, dv_acc); } } + constexpr auto kSeq0 = 64; + KDataType* k_lds_ptr = static_cast(static_cast(static_cast(smem_ptr))); auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), k_lds_write_window.get_window_origin(), - Policy::template MakeKRegBlockDescriptor()); + Policy::template MakeKRegSliceBlockDescriptor()); auto k_reg_tensor = make_static_distributed_tensor( Policy::template MakeKRegBlockDescriptor()); @@ -207,39 +209,46 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), v_lds_write_window.get_window_origin(), - Policy::template MakeVRegBlockDescriptor()); + Policy::template MakeVRegSliceBlockDescriptor()); + + auto v_reg_tensor = make_static_distributed_tensor( + Policy::template MakeVRegBlockDescriptor()); //------------------------------------------------------------------ // KT, Reg ->LDS ->Reg - auto shuffled_k_block_tile = make_static_distributed_tensor( + auto kt_block_tile = make_static_distributed_tensor( Policy::template MakeShuffledKRegWriteBlockDescriptor()); KDataType* kt_lds_ptr = static_cast(static_cast( static_cast(smem_ptr) + Policy::template GetSmemSizeK())); - auto shuffled_k_lds_write = make_tensor_view( + auto kt_lds_write = make_tensor_view( kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); - auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + auto kt_lds_write_window = make_tile_window( + kt_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); auto kt_lds_read_window = make_tile_window(kt_lds_read, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, - Policy::template MakeKTRegBlockDescriptor()); + Policy::template MakeKTRegSliceBlockDescriptor()); + + auto kt_reg_tensor = make_static_distributed_tensor( + Policy::template MakeKTRegBlockDescriptor()); //------------------------------------------------------------------ // Pre-Load KV into Registers +#if 0 auto k_block_tile = load_tile(k_dram_window); auto v_block_tile = load_tile(v_dram_window); @@ -259,31 +268,108 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto v_reg_tensor = load_tile(v_lds_read_window); -#if 0 - constexpr auto kSeq0 = 64; +#elif 1 // Looped data loading static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) { auto k_block_tile = load_tile(k_dram_window); +#if 0 + if(get_block_1d_id()==0 && get_thread_local_1d_id()<256){ + printf("iter: %01d, Tid: %03d, K_global_read: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x |\n", + i_n0.value, get_thread_local_1d_id(), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<0>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<1>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<2>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<3>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<4>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<5>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<6>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<7>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<0 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<1 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<2 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<3 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<4 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<5 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<6 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<7 + 8>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<0 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<1 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<2 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<3 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<4 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<5 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<6 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<7 + 16>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<0 + 24>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<1 + 24>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<2 + 24>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<3 + 24>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<4 + 24>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<5 + 24>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<6 + 24>{}]))), + *(reinterpret_cast(&(k_block_tile.get_thread_buffer()[number<7 + 24>{}]))) + ); + } +#endif move_tile_window(k_dram_window, {kSeq0, 0}); store_tile(k_lds_write_window, k_block_tile); - shuffle_distributed_tensor(kt_block_tile, k_block_tile); + shuffle_tile(kt_block_tile, k_block_tile); store_tile(kt_lds_write_window, kt_block_tile); block_sync_lds(); auto k_reg_tensor_slice = load_tile(k_lds_read_window); +#if 0 + if(get_block_1d_id()==0 && get_thread_local_1d_id()<256){ + printf("iter: %01d, Tid: %03d, K_lds_read: %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x | %04x %04x %04x %04x %04x %04x %04x %04x |\n", + i_n0.value, get_thread_local_1d_id(), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<0>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<1>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<2>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<3>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<4>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<5>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<6>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<7>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<0 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<1 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<2 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<3 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<4 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<5 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<6 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<7 + 8>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<0 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<1 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<2 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<3 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<4 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<5 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<6 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<7 + 16>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<0 + 24>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<1 + 24>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<2 + 24>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<3 + 24>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<4 + 24>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<5 + 24>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<6 + 24>{}]))), + *(reinterpret_cast(&(k_reg_tensor_slice.get_thread_buffer()[number<7 + 24>{}]))) + ); + } +#endif set_slice_tile(k_reg_tensor, k_reg_tensor_slice, - Sequence{}, - Sequence<(i_n0+1)*kSeq0, kQKHeaddim>{}); + sequence{}, + sequence<(i_n0 + 1) * kSeq0, kQKHeaddim>{}); auto kt_reg_tensor_slice = load_tile(kt_lds_read_window); set_slice_tile(kt_reg_tensor, kt_reg_tensor_slice, - Sequence<0, i_n0*kSeq0>{}, - Sequence{}); + sequence<0, i_n0 * kSeq0>{}, + sequence{}); block_sync_lds(); }); @@ -298,10 +384,32 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto v_reg_tensor_slice = load_tile(v_lds_read_window); set_slice_tile(v_reg_tensor, v_reg_tensor_slice, - Sequence{}, - Sequence<(i_n0+1)*kSeq0, kVHeaddim>{}); + sequence{}, + sequence<(i_n0 + 1) * kSeq0, kVHeaddim>{}); block_sync_lds(); }); +#if 0 + if(get_block_1d_id()==0 && get_thread_local_1d_id()<256){ + printf("Tid: %03d, K: %04x %04x %04x %04x %04x %04x %04x %04x \n", + get_thread_local_1d_id(), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<0>{}]))), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<1>{}]))), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<2>{}]))), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<3>{}]))), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<4>{}]))), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<5>{}]))), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<6>{}]))), + *(reinterpret_cast(&(k_reg_tensor.get_thread_buffer()[number<7>{}])))); + } +#endif #endif //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS @@ -597,11 +705,32 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP move_tile_window(do_dram_window, {kM0, 0}); s_acc = gemm_0(q_reg_tensor, k_reg_tensor); - +#if 0 + if(get_block_1d_id()==0 && get_thread_local_1d_id()<64){ + printf("Tid: %02d, Q: %04x %04x %04x %04x %04x %04x %04x %04x\n", + get_thread_local_1d_id(), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<0>{}]))), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<1>{}]))), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<2>{}]))), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<3>{}]))), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<4>{}]))), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<5>{}]))), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<6>{}]))), + *(reinterpret_cast(&(q_reg_tensor.get_thread_buffer()[number<7>{}])))); + } +#endif auto dot_reg_tensor = load_tile(dot_lds_read_window); - HotLoopScheduler::template GemmStagedScheduler<0>(); - __builtin_amdgcn_sched_barrier(0); + // HotLoopScheduler::template GemmStagedScheduler<0>(); + // __builtin_amdgcn_sched_barrier(0); // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) { @@ -715,9 +844,54 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto qt_reg_tensor = load_tile(qt_lds_read_window); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); - - HotLoopScheduler::template GemmStagedScheduler<1>(); - __builtin_amdgcn_sched_barrier(0); +#if 0 + if(get_block_1d_id()==0 && get_thread_local_1d_id()<64){ + printf("Tid: %02d, Pt: %04x %04x %04x %04x %04x %04x %04x %04x DoT: %04x %04x %04x %04x %04x %04x %04x %04x dv_acc: %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf\n", + get_thread_local_1d_id(), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<0>{}]))), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<1>{}]))), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<2>{}]))), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<3>{}]))), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<4>{}]))), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<5>{}]))), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<6>{}]))), + *(reinterpret_cast(&(pt_reg_tensor.get_thread_buffer()[number<7>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<0>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<1>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<2>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<3>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<4>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<5>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<6>{}]))), + *(reinterpret_cast(&(dot_reg_tensor.get_thread_buffer()[number<7>{}]))), + dv_acc.get_thread_buffer()[number<0>{}], + dv_acc.get_thread_buffer()[number<1>{}], + dv_acc.get_thread_buffer()[number<2>{}], + dv_acc.get_thread_buffer()[number<3>{}], + dv_acc.get_thread_buffer()[number<4>{}], + dv_acc.get_thread_buffer()[number<5>{}], + dv_acc.get_thread_buffer()[number<6>{}], + dv_acc.get_thread_buffer()[number<7>{}]); + } +#endif + // HotLoopScheduler::template GemmStagedScheduler<1>(); + // __builtin_amdgcn_sched_barrier(0); // STAGE 4, OGrad@V Gemm2 auto dp_acc = SPGradBlockTileType{}; @@ -737,8 +911,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP store_tile(d_lds_write_window, d_block_tile); - HotLoopScheduler::template GemmStagedScheduler<2>(); - __builtin_amdgcn_sched_barrier(0); + // HotLoopScheduler::template GemmStagedScheduler<2>(); + // __builtin_amdgcn_sched_barrier(0); // STAGE 5, P^T(PGrad^T - D) auto ds = SPGradBlockTileType{}; constexpr auto ds_spans = decltype(ds)::get_distributed_spans(); @@ -848,8 +1022,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP q_reg_tensor = load_tile(q_lds_read_window); lse = load_tile(lse_lds_read_window); - HotLoopScheduler::template GemmStagedScheduler<3>(); - __builtin_amdgcn_sched_barrier(0); + // HotLoopScheduler::template GemmStagedScheduler<3>(); + // __builtin_amdgcn_sched_barrier(0); // STAGE7 SGrad@K^T Gemm4 auto dq_acc = QGradBlockTileType{}; clear_tile(dq_acc); @@ -875,7 +1049,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP do_reg_tensor = load_tile(do_lds_read_window); d = load_tile(d_lds_read_window); - HotLoopScheduler::template GemmStagedScheduler<4>(); + // HotLoopScheduler::template GemmStagedScheduler<4>(); // QGrad Scale if constexpr(FmhaDropout::IsDropout) @@ -1010,13 +1184,15 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP } }(); - Policy::template PTFromGemm0CToGemm1A( - pt_reg_tensor, p_gemm); - auto dot_reg_tensor = load_tile(dot_lds_read_window); + // Policy::template PTFromGemm0CToGemm1A( + // pt_reg_tensor, p_gemm); + pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); + auto dot_reg_tensor = load_tile(dot_lds_read_window); gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); - HotLoopScheduler::template GemmStagedScheduler<1>(); - __builtin_amdgcn_sched_barrier(0); + // HotLoopScheduler::template GemmStagedScheduler<1>(); + // __builtin_amdgcn_sched_barrier(0); // STAGE 4, OGrad@V Gemm2 auto dp_acc = SPGradBlockTileType{}; @@ -1025,8 +1201,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP dp_acc = gemm_2(do_reg_tensor, v_reg_tensor); - HotLoopScheduler::template GemmStagedScheduler<2>(); - __builtin_amdgcn_sched_barrier(0); + // HotLoopScheduler::template GemmStagedScheduler<2>(); + // __builtin_amdgcn_sched_barrier(0); // STAGE 5, P^T(PGrad^T - D) auto ds = SPGradBlockTileType{}; @@ -1085,8 +1261,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto ds_reg_tensor_next = decltype(ds_reg_tensor){}; move_tile_window(ds_lds_read_window, {0, kK4}); - HotLoopScheduler::template GemmStagedScheduler<3>(); - __builtin_amdgcn_sched_barrier(0); + // HotLoopScheduler::template GemmStagedScheduler<3>(); + // __builtin_amdgcn_sched_barrier(0); // STAGE 7, SGrad@K^T Gemm4 auto dq_acc = QGradBlockTileType{}; clear_tile(dq_acc); @@ -1107,8 +1283,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP } }); - HotLoopScheduler::template GemmStagedScheduler<4>(); - __builtin_amdgcn_sched_barrier(0); + // HotLoopScheduler::template GemmStagedScheduler<4>(); + // __builtin_amdgcn_sched_barrier(0); // Results Scale if constexpr(FmhaDropout::IsDropout) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 54dc8186e4..98513e9718 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -22,6 +22,8 @@ namespace ck_tile { struct BlockFmhaBwdPipelineDefaultPolicy { + static constexpr index_t kKVSeq0 = 64; + template CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() { @@ -44,7 +46,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}), Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}), false, - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>; + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true, + true>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy{}), Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}), false, - Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>; + Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true, + true>; using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy, sequence<2, 1>>{}); #elif 1 - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kNPerBlock = kKVSeq0; constexpr index_t kMWarps = 2; constexpr index_t kKWarps = 2; @@ -401,7 +405,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t K1 = 2; constexpr index_t kMPair = 2; constexpr index_t kMRepeat = 2; - constexpr index_t kMGroup = kNPerBlock/16; + constexpr index_t kMGroup = kNPerBlock / 16; return make_static_tile_distribution( tile_distribution_encoding, @@ -437,7 +441,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy sequence<1, 2>, sequence<0, 1>>{}); #elif 1 - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kNPerBlock = kKVSeq0; constexpr index_t kMWarps = 2; constexpr index_t kKWarps = 2; @@ -447,7 +451,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t K1 = 2; constexpr index_t kMPair = 2; constexpr index_t kMRepeat = 2; - constexpr index_t kMGroup = kNPerBlock/16; + constexpr index_t kMGroup = kNPerBlock / 16; return make_static_tile_distribution( tile_distribution_encoding, @@ -715,11 +719,11 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MWarp = 2; constexpr index_t KWarp = 2; constexpr index_t KRow = 2; + constexpr index_t MRow = 2; constexpr index_t KBit0 = 2; constexpr index_t KBit1 = 2; constexpr index_t KBit2 = 2; constexpr index_t KBit3 = 2; - constexpr index_t KBit4 = 2; constexpr index_t K1 = 2; constexpr index_t MPair = 2; constexpr index_t MRepeat = 2; @@ -734,6 +738,24 @@ struct BlockFmhaBwdPipelineDefaultPolicy // M = 2^4 // K = 2^7 + // constexpr index_t kMWarps = 2; + // constexpr index_t kKWarps = 2; + // constexpr index_t kKRow = 2; + // constexpr index_t kMRow = 2; + // constexpr index_t kRowsize = 16; + // constexpr index_t K1 = 2; + // constexpr index_t kMPair = 2; + // constexpr index_t kMRepeat = 2; + + // return make_static_tile_distribution( + // tile_distribution_encoding, + // tuple, + // sequence>, + // tuple, sequence<2, 1, 2>>, + // tuple, sequence<1, 2, 2>>, + // sequence<1, 1, 2>, + // sequence<1, 3, 3>>{}); + constexpr auto lds_16x128_block_desc_raw = make_naive_tensor_descriptor( make_tuple(number{}, number{}, @@ -742,39 +764,38 @@ struct BlockFmhaBwdPipelineDefaultPolicy number{}, number{}, number{}, - number{}, + number{}, number{}, number{}, number{}), make_tuple( - number{}, - number{}, - number{}, - number{}, - number{}, - number{}, - number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, number{}, number{}, number{}, number<1>{}), - number{}, + number{}, number<1>{}); constexpr auto lds_16x128_block_desc = transform_tensor_descriptor( lds_16x128_block_desc_raw, make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}, number{})), + make_tuple(number{}, number{}, number{},number{})), make_merge_transform_v3_division_mod(make_tuple(number{}, number{}, - number{}, number{}, number{}, number{}, number{}, number{}))), - make_tuple(sequence<3, 4, 1>{}, sequence<0, 2, 7, 8, 6, 5, 9, 10>{}), + make_tuple(sequence<3, 4, 7, 1>{}, sequence<0, 2, 8, 6, 5, 9, 10>{}), make_tuple(sequence<0>{}, sequence<1>{})); return lds_16x128_block_desc; @@ -829,7 +850,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy number{}, number{}, number<1>{}), - number{}, + number{}, number<1>{}); constexpr auto lds_16x128_trans_block_desc = transform_tensor_descriptor( @@ -849,6 +870,155 @@ struct BlockFmhaBwdPipelineDefaultPolicy return lds_16x128_trans_block_desc; } + CK_TILE_HOST_DEVICE static constexpr auto Make64x128LdsBlockDescriptor() + { + constexpr index_t MWarp = 2; + constexpr index_t KWarp = 2; + constexpr index_t KRow = 2; + constexpr index_t MRow = 2; + constexpr index_t KBit0 = 2; + constexpr index_t KBit1 = 2; + constexpr index_t KBit2 = 2; + constexpr index_t KBit3 = 2; + constexpr index_t K1 = 2; + constexpr index_t MPair = 2; + constexpr index_t MRepeat = 2; + constexpr index_t MGroup = 4; + + // K:HeadDim, M:Seq, 13 Dimensions Total + // I W T I V + // Total: 4*4*64*4*2 = 2^13 + + // I W I T W I T T T T T V + // 4 2 2 2 2 2 2 2 2 2 2 2 + // MGroup, KWarp, MPair, KRow, MWarp, MRepeat, KBit<1, 2, 4, 3, 0>, K1 + // M = 2^6 + // K = 2^7 + + constexpr auto lds_64x128_block_desc_raw = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple( + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto lds_64x128_block_desc = transform_tensor_descriptor( + lds_64x128_block_desc_raw, + make_tuple(make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}, number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<0, 4, 5, 8, 2>{}, sequence<1, 3, 9, 7, 6, 10, 11>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_64x128_block_desc; + } + + CK_TILE_HOST_DEVICE static constexpr auto Make64x128TransLdsBlockDescriptor() + { + constexpr index_t MWarp = 2; + constexpr index_t KWarp = 2; + constexpr index_t KRow = 2; + constexpr index_t MRow = 2; + constexpr index_t KGroup = 2; + constexpr index_t KBit0 = 2; + constexpr index_t KBit1 = 2; + constexpr index_t KBit2 = 2; + constexpr index_t K1 = 2; + constexpr index_t MPair = 2; + constexpr index_t MRepeat = 2; + constexpr index_t MGroup = 4; + + // K:HeadDim, M:Seq, 13 Dimensions Total + // I W T I V + // Total: 4* 4*64*4*2 = 2^13 + // I W I T W I T T T T T V + // 4 2 2 2 2 2 2 2 2 2 2 2 + // MGroup, Kwarp, K1, KRow, MWarp, MRepeat, , KGroup, MRow, MPair + // M = 2^6 + // K = 2^7 + + constexpr auto lds_64x128_trans_block_desc_raw = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + number{}, + // Padding + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto lds_64x128_trans_block_desc = transform_tensor_descriptor( + lds_64x128_trans_block_desc_raw, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<1, 3, 9, 7, 6, 8, 2>{}, sequence<0, 4, 5, 10, 11>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return lds_64x128_trans_block_desc; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() { @@ -1023,11 +1193,48 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() { +#if 0 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackK(); return MakeXLdsBlockDescriptor(); +#elif 1 + return Make64x128LdsBlockDescriptor(); +#endif + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); + + // constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + constexpr index_t kNPerBlock = kKVSeq0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto k_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode); + + return k_block_dstr; } template @@ -1065,12 +1272,48 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() { +#if 0 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kVPack = GetSmemKPackV(); return MakeXLdsBlockDescriptor(); +#elif 1 + return Make64x128LdsBlockDescriptor(); +#endif + } + + template + CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = kKVSeq0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto v_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); + + return v_block_dstr; } template @@ -1108,28 +1351,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t K1 = GetAlignmentK(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentK(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t kMWarps = 2; + constexpr index_t kKWarps = 2; + constexpr index_t kKRow = 2; + constexpr index_t kMRow = 2; + constexpr index_t kRowsize = 16; + constexpr index_t K1 = 2; + constexpr index_t kMPair = 2; + constexpr index_t kMRepeat = 2; + constexpr index_t kMGroup = kKVSeq0 / 16; return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + tuple, + sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 3, 2>>, + sequence<1, 1, 2, 1>, + sequence<0, 2, 3, 4>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor() { +#if 0 // Hold all data constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; @@ -1138,6 +1383,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kKPackT = GetSmemKPackKT(); return MakeXTLdsBlockDescriptor(); +#elif 1 + return Make64x128TransLdsBlockDescriptor(); +#endif } template @@ -1156,6 +1404,38 @@ struct BlockFmhaBwdPipelineDefaultPolicy make_tuple(sequence<0>{}, sequence<1>{})); } + template + CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegSliceBlockDescriptor() + { + using BlockGemm = remove_cvref_t())>; + constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp(); + using WarpGemm = remove_cvref_t())>; + + constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); + constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); + + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kKPerBlock = kKVSeq0; + + constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); + constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; + + constexpr auto kt_block_outer_dstr_encoding = + tile_distribution_encoding, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 0>>{}; + + constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding( + kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{}); + + constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode); + + return kt_block_dstr; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor() { @@ -1261,7 +1541,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor() { - return Make16x128TransLdsBlockDescriptor(); + return Make16x128TransLdsBlockDescriptor(); } template @@ -2047,6 +2327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy } private: + // Read64Seq per tile for KV static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index e989d97188..11f54f3a04 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -25,6 +25,11 @@ using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl, 2>>; +using WarpGemmMfmaF16F16F32M16N16K32StaggeredK = WarpGemmImpl, + 2, + true>>; + using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl, 1>>; @@ -84,6 +89,11 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl, 2>>; +using WarpGemmMfmaBf16Bf16F32M16N16K32StaggeredK = WarpGemmImpl, + 2, + true>>; + using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl, 1>>; diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 685950d6d4..135463dd3b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -73,7 +73,7 @@ struct WarpGemmAtrributeMfma } }; -template +template struct WarpGemmAtrributeMfmaIterateK { static_assert(kKIter > 0, "wrong!"); @@ -102,79 +102,161 @@ struct WarpGemmAtrributeMfmaIterateK CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + if constexpr(!StaggeredK) { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // each M blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + else { - // each M blocks share the same data - return tile_distribution_encoding< - sequence, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) - { - // single block to multi-block thread mapping - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // each M blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } } } CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() { - if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + if constexpr(!StaggeredK) { - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // each N blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } } - else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + else { - // single block to multi-block thread mapping - return tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; - } - else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) - { - // each N blocks share the same data - return tile_distribution_encoding< - sequence, - tuple, - sequence>, - tuple>, - tuple>, - sequence<2>, - sequence<1>>{}; + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // single block to multi-block thread mapping + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // each N blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2>, + sequence<0, 2>>{}; + } } } diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 299e1fcd4b..402b9710f0 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -16,7 +16,8 @@ template + bool SwizzleA = false, + bool StaggeredK = false> struct WarpGemmMfmaDispatcher; // clang-format off @@ -35,6 +36,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N16K32StaggeredK; }; // bf16 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; @@ -51,6 +53,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32StaggeredK; }; // fp8 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; @@ -72,7 +75,8 @@ template + bool SwizzleA = false, + bool StaggeredK = false> using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher::Type; + SwizzleA, + StaggeredK>::Type; } // namespace ck_tile