diff --git a/cmake/EnableCompilerWarnings.cmake b/cmake/EnableCompilerWarnings.cmake index fb2b38d688..bfdd3a7327 100644 --- a/cmake/EnableCompilerWarnings.cmake +++ b/cmake/EnableCompilerWarnings.cmake @@ -108,7 +108,7 @@ else() endif() list(APPEND CMAKE_COMPILER_WARNINGS -Wno-missing-field-initializers - -Wno-deprecated-declarations + # -Wno-deprecated-declarations ) endif() add_definitions(${CMAKE_COMPILER_WARNINGS}) diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py index 0dd3862b1a..275790a4bc 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_batch_decode.py @@ -106,6 +106,10 @@ static void run(const ck_tile::stream_config& s, fmha_batch_decode_args a) auto [kargs, grids] = fmha_batch_decode_create_kargs_and_grids(a); constexpr dim3 blocks = k_::BlockSize(); constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu; + printf("run into fmha_batch_decode_bf16_nlogits_nbias_nmask_lse_nsquant_pagedkv\\n"); + printf("blocks: %d, %d, %d\\n", blocks.x, blocks.y, blocks.z); + printf("grids: %d, %d, %d\\n", grids.x, grids.y, grids.z); + printf("kBlockPerCu: %d\\n", kBlockPerCu); ck_tile::make_kernel(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}}); }} }}; diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index be84842347..b4f0485996 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -62,6 +62,7 @@ #include "ck_tile/core/tensor/transpose_tile.hpp" #include "ck_tile/core/tensor/update_tile.hpp" #include "ck_tile/core/utility/bit_cast.hpp" +#include "ck_tile/core/utility/debug.hpp" #include "ck_tile/core/utility/env.hpp" #include "ck_tile/core/utility/functional.hpp" #include "ck_tile/core/utility/functional_with_tuple.hpp" diff --git a/include/ck_tile/core/tensor/tile_scatter_gather_debug.hpp b/include/ck_tile/core/tensor/tile_scatter_gather_debug.hpp index cfdf5cccef..c0f39ec43b 100644 --- a/include/ck_tile/core/tensor/tile_scatter_gather_debug.hpp +++ b/include/ck_tile/core/tensor/tile_scatter_gather_debug.hpp @@ -205,7 +205,7 @@ struct tile_scatter_gather_debug auto bottom_tensor_thread_coord = bottom_tensor_thread_coord_tmp; constexpr auto idx_diff_ys = - SFC_Ys::get_step_between(number<0>{}, number{}); + SFC_Ys::get_step_between(number<0>{}, number{}); // NumAccessPerCoord = 4 constexpr auto idx_diff_ps_ys = container_concat( generate_tuple([&](auto) { return number<0>{}; }, number{}), idx_diff_ys); @@ -338,14 +338,14 @@ struct tile_scatter_gather_debug constexpr auto idx_gather = idx_ys_start[number{}]; const auto page_offset = page_idx_[idx_gather]; // read from bottom tensor - auto idxx = bottom_tensor_thread_coord.get_index(); + auto idxx = bottom_tensor_thread_coord.get_index(); // 16 * 128 const vector_t vec_value = get_bottom_tensor_view().template get_vectorized_elements( bottom_tensor_thread_coord, page_offset, bool_constant{}); - // printf("bid %d tid %d coord_offset:%d %d %d, page_offset:%d v %f\n", blockIdx.x, threadIdx.x, - // bottom_tensor_thread_coord.get_offset(), idxx(I0), idxx(I1), page_offset, type_convert(vec_value.template get_as()[0]), type_convert(vec_value.template get_as()[4])); + // printf("bid %d tid %d coord_offset:%d %d %d, page_offset:%d v %f %f \n", blockIdx.x, threadIdx.x, + // bottom_tensor_thread_coord.get_offset(), idxx(I0), idxx(I1), page_offset, type_convert(vec_value.template get_as()[0]), type_convert(vec_value.template get_as()[0]), type_convert(vec_value.template get_as()[4])); #if 1 // write into distributed tensor static_for<0, Traits::ScalarPerVector, Traits::PackedSize>{}([&](auto j) { @@ -354,9 +354,9 @@ struct tile_scatter_gather_debug return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j) : idx_ys_start[jj]; }, - number{}); + number{}); // idx_ys = tuple, ck_tile::constant<2>, ck_tile::constant<4>> - constexpr index_t d = + constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys) / Traits::PackedSize; diff --git a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp index 7272edad04..23ac1f0bed 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_batch_decode_kernel.hpp @@ -738,8 +738,8 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel { const auto v_dram_naive = make_naive_tensor_view( v_ptr, - make_tuple(kargs.num_total_pages * kargs.page_block_size, kargs.hdim_v), - make_tuple(kargs.stride_v, 1), + make_tuple(kargs.num_total_pages / 16, kargs.hdim_v, 16 / 8, 8), + make_tuple(kargs.hdim_v * 16, 16, 8, 1), number{}, number<1>{}); @@ -747,15 +747,15 @@ struct FmhaBatchDecodeWithPagedKVCacheKernel v_dram_naive, make_tuple( make_pass_through_transform(kargs.hdim_v), - make_pass_through_transform(kargs.num_total_pages * kargs.page_block_size)), - make_tuple(sequence<1>{}, sequence<0>{}), + make_merge_transform(make_tuple(kargs.num_total_pages / 16, 16 / 8, 8))), + make_tuple(sequence<1>{}, sequence<0, 2, 3>{}), make_tuple(sequence<0>{}, sequence<1>{})); - constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; + // constexpr bool kPadSeqLenK_ = kUseAsyncCopy ? kPadSeqLenK : true; return pad_tensor_view( v_dram_transposed, make_tuple(number{}, number{}), - sequence{}); + sequence{}); } else { diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp index 325dd5128d..fa9bf1437e 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs.hpp @@ -219,7 +219,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemmPreshuffled(); - constexpr auto gemm_1 = Policy::template GetKVBlockGemm(); + constexpr auto gemm_1 = Policy::template GetKVBlockGemmPreshuffled(); auto q_dram_window = make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(), @@ -259,7 +259,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( q_origin.at(number<0>{}), number{}, number{}, num_splits, i_split); - const index_t num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); + const index_t num_total_loop = + integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0); // [64, 0, 64 -> 1] // check early exit if no work to do if(num_total_loop <= 0) @@ -297,22 +298,24 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N Policy::template MakeBiasDramTileDistribution()); - auto v_dist = Policy::template MakeVDramTileDistribution(); - auto v_coord = v_dist.calculate_index(); - const auto VPageIndexDim = I1; - using VDstrEncode = typename decltype(v_dist)::DstrEncode; - constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I3]; + auto v_dist = Policy::template MakeVDramTileDistributionPreshuffled(); + auto v_coord = v_dist.calculate_index(); + const auto VPageIndexDim = I1; + using VDstrEncode = typename decltype(v_dist)::DstrEncode; + constexpr index_t V_KRepeat = VDstrEncode::hs_lengthss_[I1][I0]; statically_indexed_array v_offsets; static_for<0, V_KRepeat, 1>{}([&](auto k0) { v_offsets[k0] = kv_page_indices[v_coord[VPageIndexDim] + k0.value] * stride_v; }); - auto v_dram_window = - make_tile_scatter_gather(v_dram_block_window_tmp.get_bottom_tensor_view(), - v_dram_block_window_tmp.get_window_lengths(), - {0, seqlen_k_start}, // TODO: hdim split? - v_dist, - v_offsets, - VPageIndexDim); + // v_dram_block_window_tmp.get_bottom_tensor_view(): see dram_bottom_tensor_view_record + auto v_dram_window = make_tile_scatter_gather_debug( + v_dram_block_window_tmp.get_bottom_tensor_view(), + v_dram_block_window_tmp + .get_window_lengths(), // tuple, ck_tile::constant<64>> + {0, seqlen_k_start}, // TODO: hdim split? + v_dist, + v_offsets, + VPageIndexDim); // store Q into LDS __builtin_amdgcn_sched_barrier(0); @@ -343,21 +346,25 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS static_assert(1 <= k1_loops); auto k_dram_window = [&] { - auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled(); - auto k_coord = k_dist.calculate_index(); - using KDstrEncode = typename decltype(k_dist)::DstrEncode; + auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled(); + auto k_coord = k_dist.calculate_index(); + using KDstrEncode = typename decltype(k_dist)::DstrEncode; // constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I1]; constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; static_for<0, NRepeat, 1>{}([&](auto n0) { - k_offsets[n0] = kv_page_indices[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * stride_k; + k_offsets[n0] = + kv_page_indices[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * stride_k; + // printf("threadIdx.x %d, k_offsets[%d, %d] = %d\n", threadIdx.x, k_coord[0], + // k_coord[1], k_offsets[n0]); }); - return make_tile_scatter_gather_debug(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - k_dist, - k_offsets); // K DRAM tile window for + return make_tile_scatter_gather_debug( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), // [64, 128] + k_dram_block_window.get_window_origin(), // + k_dist, + k_offsets); // K DRAM tile window for }(); // load the first tile of the first iteration and store to LDS @@ -379,7 +386,8 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS // sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { // constexpr auto i_j_idx = make_tuple(idx0, idx1); // if(blockIdx.x==0) - // printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, type_convert(k_block_tile(i_j_idx))); + // printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, + // type_convert(k_block_tile(i_j_idx))); // }); // }); if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) @@ -409,19 +417,21 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS // store_tile( // k_lds_window, // tile_elementwise_in(k_element_func, k_block_tile)); // LDS write i + 1 - k_block_tile = load_tile(k_dram_window); // global read i + 2 + k_block_tile = load_tile(k_dram_window); // global read i + 2 }); - + // sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) { // sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) { // constexpr auto i_j_idx = make_tuple(idx0, idx1); // if(blockIdx.x==0) - // printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, type_convert(k_block_tile(i_j_idx))); + // printf("bid %d tid %d %f\n", blockIdx.x, threadIdx.x, + // type_convert(k_block_tile(i_j_idx))); // }); // }); } - const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + // const auto v_prefetch = load_tile(v_dram_window); // prefetch load v tile + const auto v_block_tile = load_tile(v_dram_window); static_for<0, V_KRepeat, 1>{}([&](auto k0) { v_offsets[k0] = kv_page_indices[kK1 + v_coord[VPageIndexDim] + k0.value] * stride_v; @@ -488,8 +498,7 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS { s_acc = tile_elementwise_in(s_acc_element_func, s_acc); #if !CK_TILE_FMHA_FWD_FAST_EXP2 - tile_elementwise_inout([&scale_s](auto& x) { - x = x * scale_s; }, s_acc); + tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc); #else if constexpr(kHasLogitsSoftCap) { @@ -545,22 +554,23 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS move_tile_window(k_dram_block_window, {kN0, 0}); k_dram_window = [&] { - auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled(); - auto k_coord = k_dist.calculate_index(); - using KDstrEncode = typename decltype(k_dist)::DstrEncode; + auto k_dist = Policy::template MakeKDramTileDistributionPreshuffled(); + auto k_coord = k_dist.calculate_index(); + using KDstrEncode = typename decltype(k_dist)::DstrEncode; constexpr index_t NRepeat = KDstrEncode::hs_lengthss_[I0][I0]; statically_indexed_array k_offsets; static_for<0, NRepeat, 1>{}([&](auto n0) { k_offsets[n0] = - (kv_page_indices + kN0)[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * + (kv_page_indices + kN0)[k_coord[0] + kN0 / NRepeat * n0.value] / 16 * stride_k; }); - return make_tile_scatter_gather_debug(k_dram_block_window.get_bottom_tensor_view(), - k_dram_block_window.get_window_lengths(), - k_dram_block_window.get_window_origin(), - k_dist, - k_offsets); // K DRAM tile window for + return make_tile_scatter_gather_debug( + k_dram_block_window.get_bottom_tensor_view(), + k_dram_block_window.get_window_lengths(), + k_dram_block_window.get_window_origin(), + k_dist, + k_offsets); // K DRAM tile window for }(); // laod the first tile of the first iteration and store to LDS @@ -682,28 +692,30 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS }); }); - block_sync_lds(); - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v_prefetch); - store_tile( - v_lds_window, - tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch - } - move_tile_window(v_dram_window, {0, kK1}); + // block_sync_lds(); + // if constexpr(std::is_same_v) + // { + // auto v_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledVRegBlockDescriptor()); + // shuffle_tile(v_shuffle_tmp, v_prefetch); + // store_tile( + // v_lds_window, + // tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch + // } + // else + // { + // store_tile(v_lds_window, + // tile_elementwise_in(v_element_func, v_prefetch)); // store the + // prefetch + // } + + move_tile_window(v_dram_window, {0, kK1}); // STAGE 3, KV gemm if constexpr(k1_loops > 1) { static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) { - const auto v = load_tile(v_dram_window); // load next v + // const auto v = load_tile(v_dram_window); // load next v static_for<0, V_KRepeat, 1>{}([&](auto k0) { v_offsets[k0] = kv_page_indices[kK1 * 2 + i_k1.value * kK1 + @@ -713,35 +725,152 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVS v_dram_window.update_page_idx(v_offsets); block_sync_lds(); - gemm_1(o_acc, - get_slice_tile( - p, sequence<0, i_k1 * kK1>{}, sequence{}), - v_lds_window); + + // gemm_1(o_acc, + // get_slice_tile( + // p, sequence<0, i_k1 * kK1>{}, sequence{}), + // v_block_tile); block_sync_lds(); - if constexpr(std::is_same_v) - { - auto v_shuffle_tmp = make_static_distributed_tensor( - Policy::template MakeShuffledVRegBlockDescriptor()); - shuffle_tile(v_shuffle_tmp, v); - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, - v_shuffle_tmp)); // store the prefetch - } - else - { - store_tile(v_lds_window, - tile_elementwise_in(v_element_func, v)); // store next v - } + // if constexpr(std::is_same_v) + // { + // auto v_shuffle_tmp = make_static_distributed_tensor( + // Policy::template MakeShuffledVRegBlockDescriptor()); + // shuffle_tile(v_shuffle_tmp, v); + // store_tile(v_lds_window, + // tile_elementwise_in(v_element_func, + // v_shuffle_tmp)); // store the prefetch + // } + // else + // { + // store_tile(v_lds_window, + // tile_elementwise_in(v_element_func, v)); // store next v + // } move_tile_window(v_dram_window, {0, kK1}); + const auto v_block_tile = load_tile(v_dram_window); }); } // tail { block_sync_lds(); + auto temp = get_slice_tile( + p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}); + + // using a = ck_tile::static_distributed_tensor< + // unsigned short, + // ck_tile::tile_distribution< + // ck_tile::tensor_adaptor< + // ck_tile::tuple< + // ck_tile::replicate>>, + // ck_tile::unmerge, + // ck_tile::constant<1>, + // ck_tile::constant<16>>, + // false>, + // ck_tile::unmerge, + // ck_tile::constant<4>, + // ck_tile::constant<4>, + // ck_tile::constant<4>>, + // false>, + // ck_tile::merge_v2_magic_division< + // ck_tile::tuple, ck_tile::constant<4>>>, + // ck_tile::merge_v2_magic_division< + // ck_tile::tuple, + // ck_tile::constant<16>>>>, + // ck_tile::tuple, + // ck_tile::sequence<0>, + // ck_tile::sequence<1>, + // ck_tile::sequence<4, 2>, + // ck_tile::sequence<8, 5>>, + // ck_tile::tuple, + // ck_tile::sequence<3, 4, 5>, + // ck_tile::sequence<6, 7, 8, 9>, + // ck_tile::sequence<10>, + // ck_tile::sequence<11>>, + // ck_tile::sequence<0, 1>, + // ck_tile::sequence<10, 11, 3, 6, 7, 9>>, + // ck_tile::tensor_descriptor< + // ck_tile::tuple, + // ck_tile::constant<1>, + // ck_tile::constant<4>, + // ck_tile::constant<4>>, + // false>>, + // ck_tile::tuple>, + // ck_tile::tuple>, + // ck_tile::sequence<1, 2, 3, 4>, + // ck_tile::constant<16>, + // ck_tile::sequence<-1, -1, -1, -1, -1>, + // ck_tile::sequence<-1, -1, -1, -1, -1>>, + // ck_tile::tile_distribution_encoding< + // ck_tile::sequence<4>, + // ck_tile::tuple, + // ck_tile::sequence<1, 4, 4, 4>>, + // ck_tile::tuple, ck_tile::sequence<2, 1>>, + // ck_tile::tuple, ck_tile::sequence<2, 2>>, + // ck_tile::sequence<1, 2, 2, 2>, + // ck_tile::sequence<0, 0, 1, 3>>, + // ck_tile::detail::tile_distribution_detail< + // ck_tile::tuple, + // ck_tile::sequence<3, 4, 5>, + // ck_tile::sequence<6, 7, 8, 9>>>>>; + + // using b = ck_tile::static_distributed_tensor< + // unsigned short, + // ck_tile::tile_distribution< + // ck_tile::tensor_adaptor< + // ck_tile::tuple< + // ck_tile::replicate>>, + // ck_tile::unmerge, + // ck_tile::constant<2>, + // ck_tile::constant<16>>, + // false>, + // ck_tile::unmerge, + // ck_tile::constant<2>, + // ck_tile::constant<2>, + // ck_tile::constant<8>>, + // false>, + // ck_tile::merge_v2_magic_division>>, + // ck_tile::merge_v2_magic_division< + // ck_tile::tuple, + // ck_tile::constant<2>, + // ck_tile::constant<16>>>>, + // ck_tile::tuple, + // ck_tile::sequence<0>, + // ck_tile::sequence<1>, + // ck_tile::sequence<3>, + // ck_tile::sequence<7, 8, 5>>, + // ck_tile::tuple, + // ck_tile::sequence<3, 4, 5>, + // ck_tile::sequence<6, 7, 8, 9>, + // ck_tile::sequence<10>, + // ck_tile::sequence<11>>, + // ck_tile::sequence<0, 1>, + // ck_tile::sequence<10, 11, 6, 4, 9>>, + // ck_tile::tensor_descriptor< + // ck_tile::tuple, + // ck_tile::constant<2>, + // ck_tile::constant<8>>, + // false>>, + // ck_tile::tuple>, + // ck_tile::tuple>, + // ck_tile::sequence<1, 2, 3>, + // ck_tile::constant<32>, + // ck_tile::sequence<-1, -1, -1, -1>, + // ck_tile::sequence<-1, -1, -1, -1>>, + // ck_tile::tile_distribution_encoding< + // ck_tile::sequence<1>, + // ck_tile::tuple, ck_tile::sequence<2, 2, 2, 8>>, + // ck_tile::tuple, ck_tile::sequence<2, 2, 1>>, + // ck_tile::tuple, ck_tile::sequence<1, 2, 2>>, + // ck_tile::sequence<2, 1, 2>, + // ck_tile::sequence<0, 1, 3>>, + // ck_tile::detail::tile_distribution_detail< + // ck_tile::tuple, + // ck_tile::sequence<3, 4, 5>, + // ck_tile::sequence<6, 7, 8, 9>>>>> + gemm_1(o_acc, get_slice_tile( p, sequence<0, (k1_loops - 1) * kK1>{}, sequence{}), - v_lds_window); + v_block_tile); block_sync_lds(); } kv_page_indices += kN0; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp index 37ed4382e5..e5a0986f81 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_batch_decode_pipeline_qr_ks_vs_default_policy.hpp @@ -8,6 +8,7 @@ #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp" +#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v3.hpp" namespace ck_tile { @@ -59,20 +60,20 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy static_assert(0 < ElemPerThread); constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize); - constexpr index_t KPerThread = kMaxVecLoad; - constexpr index_t KThreads = kKPerBlock / KPerThread; - constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; - constexpr index_t NumWarps = kBlockSize / get_warp_size(); - constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); + constexpr index_t KPerThread = kMaxVecLoad; // 8 + constexpr index_t KThreads = kKPerBlock / KPerThread; // 16 + constexpr index_t MThreadPerWarp = get_warp_size() / KThreads; // 4 + constexpr index_t NumWarps = kBlockSize / get_warp_size(); // 4 + constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps); // 1 return make_static_tile_distribution( tile_distribution_encoding, - tuple, - sequence>, + tuple, // 1, 4, 4 + sequence>, // 16, 8 tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, + tuple, sequence<2, 0>>, // NumWarps, MThreadPerWarp, KThreads 4, 4, 16 sequence<1, 2>, - sequence<0, 1>>{}); + sequence<0, 1>>{}); // MPerThread, KPerThread 1, 8 } // template @@ -227,26 +228,58 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; // 64 + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; // 128 constexpr index_t MaxVectorSize = 16 / sizeof(KDataType); - constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; + constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize; // 32 - constexpr index_t K2 = min(MaxVectorSize, ElemPerThread); //8 + constexpr index_t K2 = min(MaxVectorSize, ElemPerThread); // 8 constexpr index_t K1 = 4; - constexpr index_t K0 = kKPerBlock / K1 / K2; // 2 - constexpr index_t N2 = 16; // 8 - constexpr index_t N1 = kBlockSize / get_warp_size(); - constexpr index_t N0 = kNPerBlock / (N2 * N1); + constexpr index_t K0 = kKPerBlock / K1 / K2; // 4 + constexpr index_t N2 = 16; + constexpr index_t N1 = kBlockSize / get_warp_size(); // 4 + constexpr index_t N0 = kNPerBlock / (N2 * N1); // 1 return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, sequence>, // 1, 4, 16, 4, 4, 8 tuple, sequence<2, 1>>, - tuple, sequence<1, 2>>, + tuple, sequence<1, 2>>, // N1, K1, N2 : 4, 4, 16 sequence<1, 2, 2>, - sequence<0, 0, 2>>{}); + sequence<0, 0, 2>>{}); // N0, K0, K2 : 1, 4, 8 + } + template + CK_TILE_DEVICE static constexpr auto MakeVDramTileDistributionPreshuffled() + { + using VLayout = remove_cvref_t; + using VDataType = remove_cvref_t; + + constexpr index_t kBlockSize = Problem::kBlockSize; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128 + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64 + + static_assert(std::is_same_v, "wrong"); + + constexpr index_t N2 = 16; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // 32 + constexpr index_t MaxVectorSize = 16 / sizeof(VDataType); + constexpr index_t K3 = min(MaxVectorSize, total_pixels); // 8 + constexpr index_t K1 = 2; + constexpr index_t K2 = 2; + constexpr index_t K0 = kKPerBlock / K1 / K2 / K3; // 2 + constexpr index_t N0 = kBlockSize / get_warp_size(); // 4 + constexpr index_t N1 = kNPerBlock / (N2 * N0); // 2 + + static_assert(kKPerBlock == K0 * K1 * K2 * K3); + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, sequence>, // 4, 2, 16, 2, 2, 2, 8 + tuple, sequence<2, 2, 1>>, + tuple, sequence<1, 2, 2>>, // N0, K1, K2, N2 : 4, 2, 2, 16 + sequence<2, 1, 2>, + sequence<0, 1, 3>>{}); // K0, N1, K3 : 2, 2, 8 + } template CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution() @@ -320,6 +353,67 @@ struct BlockFmhaBatchDecodeWithPagedKVCachePipelineQRKSVSDefaultPolicy static_assert(1 < Problem::kNumGemm0Warps); return BlockGemmARegBRegCRegV2{}; } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemmPreshuffled() + { + using GemmProblem = + BlockGemmProblem, + typename Problem::BlockFmhaShape::Gemm1BlockWarps, + typename Problem::BlockFmhaShape::Gemm1WarpTile>>; + + constexpr auto warp_gemm = []() { + constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}); + static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32); + + if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaF16F16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + if constexpr(WarpGemmM == 32) + return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{}; + else if constexpr(WarpGemmM == 16) + return WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution{}; + else // WarpGemmM == 4 + return WarpGemmMfmaBf16Bf16F32M4N64K16{}; + } + else if constexpr(std::is_same_v && + std::is_same_v && + std::is_same_v) + { + static_assert(WarpGemmM == 32); + + // TODO: hard coded here. Otherwise, it may incorrect result + return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<>{}; + } // TODO - bf8_t + }(); + + using BlockGemmPolicy = + BlockGemmARegBSmemCRegV2CustomPolicy; + static_assert(1 < Problem::kNumGemm1Warps); + return BlockGemmARegBRegCRegV3{}; + } }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp index b27e9a5baa..af72765e70 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp @@ -152,7 +152,7 @@ struct BlockFmhaPipelineQXCustomPolicy constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType); // this should align with MakeQDramTileDistribution() - constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; + constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize; // 16 * 128 / 256 = 8 static_assert(0 < ElemPerThread); return min(ElemPerThread, MaxVectorSize); } @@ -729,16 +729,16 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy, - tuple, sequence>, + tuple, sequence>, // 4, 4, 8, 8, 8 tuple, sequence<1, 2>>, - tuple, sequence<2, 0>>, + tuple, sequence<2, 0>>, // N1, N2, K0: 4, 8, 8 sequence<1, 2>, - sequence<0, 1>>{}); + sequence<0, 1>>{}); // N0, K1: 4,8 } else { @@ -778,32 +778,32 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy; constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1; // 128 + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; // 64 if constexpr(std::is_same_v) { - constexpr index_t N1 = GetAlignmentV(); - constexpr index_t N0 = kNPerBlock / N1; // P + constexpr index_t N1 = GetAlignmentV(); // 8 + constexpr index_t N0 = kNPerBlock / N1; // 16 - constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; + constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; // 32 static_assert(total_pixels % N1 == 0); // TODO: this is not always true? - constexpr index_t K3 = total_pixels / N1; - constexpr index_t kKPack = GetSmemKPackV(); + constexpr index_t K3 = total_pixels / N1; // 4 + constexpr index_t kKPack = GetSmemKPackV(); // 8 static_assert(kKPack % K3 == 0); - constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave + constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave // 2 if constexpr(get_warp_size() % (K2 * N0) == 0) { - constexpr index_t K1 = get_warp_size() / (K2 * N0); - constexpr index_t K0 = kBlockSize / get_warp_size(); + constexpr index_t K1 = get_warp_size() / (K2 * N0); // 2 + constexpr index_t K0 = kBlockSize / get_warp_size(); // 4 static_assert(kKPerBlock == K0 * K1 * K2 * K3); return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, + tuple, sequence>, // 16, 8, 4, 2, 2, 4 tuple, sequence<2, 1, 2>>, - tuple, sequence<1, 0, 2>>, + tuple, sequence<1, 0, 2>>, // K0, K1, N0, K2: 4, 2, 16, 2 sequence<2, 1>, - sequence<3, 1>>{}); + sequence<3, 1>>{}); // K3, N1: 4, 8 } else { diff --git a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp index 5b421c210f..a4690500d1 100644 --- a/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp +++ b/include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp @@ -44,15 +44,16 @@ struct BlockGemmARegBRegCRegV2 "wrong!"); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp(); + // const const tuple, 2>>, int, int> using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); + constexpr index_t MWarp = config.template at<1>(); // 1 + constexpr index_t NWarp = config.template at<2>(); // 4 constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // WG::kM, kN, kK: 16, 16, 32 constexpr index_t NPerBlockPerIter = NPerBlock / NIterPerWarp; constexpr index_t KPerBlockPerIter = KPerBlock / KIterPerWarp; @@ -60,13 +61,18 @@ struct BlockGemmARegBRegCRegV2 const index_t iNWarp = get_warp_id() % NWarp; // if(threadIdx.x%64==0) // printf("tid %d %d %d %d %d KIterPerWarp %d\n",threadIdx.x, MPerBlock, NPerBlock, NIterPerWarp, iNWarp, KIterPerWarp); + + // tid 0 16 64 1 0 KIterPerWarp 4 + // tid 64 16 64 1 1 KIterPerWarp 4 + // tid 128 16 64 1 2 KIterPerWarp 4 + // tid 192 16 64 1 3 KIterPerWarp 4 constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding< sequence<>, tuple, sequence>, tuple>, - tuple>, + tuple>, // MWarp, NWarp 1, 4 sequence<1, 2>, - sequence<0, 0>>{}; + sequence<0, 0>>{}; // MIterPerWarp, NIterPerWarp, 1, 1 constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); @@ -123,9 +129,10 @@ struct BlockGemmARegBRegCRegV2 // printf("get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size()); // if(threadIdx.x==0) - // printf("get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size()); + // printf("b get_thread_buffer_size %d\n", b_warp_tensor.get_thread_buffer_size()); // auto &xx = b_warp_tensor.get_thread_buffer(); // printf("bid %d tid %d b %f %f %f %f\n", blockIdx.x, threadIdx.x, type_convert(xx[number<0>{}]), type_convert(xx[number<2>{}]), type_convert(xx[number<4>{}]), type_convert(xx[number<6>{}])); + static_for<0, MIterPerWarp, 1>{}([&](auto mIter) { // read A warp tensor from A block tensor AWarpTensor a_warp_tensor; @@ -196,19 +203,19 @@ struct BlockGemmARegBRegCRegV2 using WG = remove_cvref_t())>; - constexpr index_t MWarp = config.template at<1>(); - constexpr index_t NWarp = config.template at<2>(); + constexpr index_t MWarp = config.template at<1>(); // 1 + constexpr index_t NWarp = config.template at<2>(); // 4 - constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); - constexpr index_t KIterPerWarp = KPerBlock / WG::kK; + constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); // 1 + constexpr index_t KIterPerWarp = KPerBlock / WG::kK; // 4 constexpr auto a_block_outer_dstr_encoding = - tile_distribution_encoding, - tuple, sequence>, - tuple>, + tile_distribution_encoding, // 4 + tuple, sequence>, // 1, 1, 4 tuple>, + tuple>, // MWarp, NWarp 1, 4 sequence<1, 2>, - sequence<0, 0>>{}; + sequence<0, 0>>{}; // MIterPerWarp, KIterPerWarp 1, 4 constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index 93ccdb5f57..4b61e86652 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -478,16 +478,32 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() { + // static constexpr index_t kM = 16; + // static constexpr index_t kN = 16; + // static constexpr index_t kK = 16; + // static constexpr index_t kAMBlock = 1; + // static constexpr index_t kBNBlock = 1; + + // static constexpr index_t kAMLane = 16; + // static constexpr index_t kBNLane = 16; + // static constexpr index_t kABKLane = 4; + // static constexpr index_t kABKPerLane = 4; + + // static constexpr index_t kCMLane = 4; + // static constexpr index_t kCNLane = 16; + // static constexpr index_t kCM0PerLane = 1; + // static constexpr index_t kCM1PerLane = 4; + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) { return tile_distribution_encoding< sequence<>, - tuple, - sequence>, + tuple, // 32/16 warp shape N + sequence>, // 2/4, 8 instruction data layout tuple>, - tuple>, + tuple>, // 4, 16 sequence<2>, - sequence<1>>{}; + sequence<1>>{}; // 8 } else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) {