diff --git a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp index d58b334e61..2c1264eee3 100644 --- a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp +++ b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp @@ -38,8 +38,7 @@ using fmha_bwd_convert_dq_0 = using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel; -using convert_dq_trait_0 = - fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>; +using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, true, false, false>; template <> void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, @@ -67,6 +66,7 @@ using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; using fmha_warp_tile0_0 = ck_tile::sequence<16, 16, 32>; using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; +using fmha_warp_tile2_0 = ck_tile::sequence<16, 32, 16>; // TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape // G0&G2 -> GSdP @@ -76,16 +76,16 @@ using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits(&(q_host(0, 0, iK, iM))))); + // } + // printf("\n"); + // } + ck_tile::DeviceMem q_buf(q_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem k_buf(k_host.get_element_space_size_in_bytes()); ck_tile::DeviceMem v_buf(v_host.get_element_space_size_in_bytes()); diff --git a/include/ck_tile/core/container/sequence.hpp b/include/ck_tile/core/container/sequence.hpp index 4fcea9642d..5c5347ec18 100644 --- a/include/ck_tile/core/container/sequence.hpp +++ b/include/ck_tile/core/container/sequence.hpp @@ -1081,8 +1081,26 @@ struct sorted_sequence_histogram, sequence> { h.template at() += 1; } + else + { + sorted_sequence_histogram, sequence>{}(h); + } } }; + +template +struct sorted_sequence_histogram, sequence> +{ + template + constexpr auto operator()(Histogram& h) + { + if constexpr(x < r) + { + h.template at() += 1; + } + } +}; + } // namespace detail template diff --git a/include/ck_tile/core/tensor/shuffle_tile.hpp b/include/ck_tile/core/tensor/shuffle_tile.hpp index 55e3274cde..5418375180 100644 --- a/include/ck_tile/core/tensor/shuffle_tile.hpp +++ b/include/ck_tile/core/tensor/shuffle_tile.hpp @@ -46,23 +46,22 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT return rh_major_minor_to_y_; }; - constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{}); - constexpr auto rh_major_minor_to_y_out = get_rh_major_minor_to_y(OutTensor{}); + constexpr auto rh_major_minor_to_y_in = get_rh_major_minor_to_y(InTensor{}); - constexpr auto y_dim_out_to_in = [&] { - map y_dim_out_to_in_; - - for(const auto& [rh_major_minor, y_out] : rh_major_minor_to_y_out) - { - y_dim_out_to_in_(y_out) = rh_major_minor_to_y_in[rh_major_minor]; - } - - return y_dim_out_to_in_; - }(); - - // constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y(); + using OutDstrEncode = typename decltype(out_tensor.get_tile_distribution())::DstrEncode; + // using InDstrEncode = typename decltype(in_tensor.get_tile_distribution())::DstrEncode; + + constexpr auto y_dim_out_to_in = generate_sequence_v2( + [&](auto i) constexpr { + constexpr index_t rh_major_out = OutDstrEncode::ys_to_rhs_major_[i]; + constexpr index_t rh_minor_out = OutDstrEncode::ys_to_rhs_minor_[i]; + + return number{}; + }, + number{}); + constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths()); // input and output vector dim in the order of input Y dims @@ -128,7 +127,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT // set output vectors static_for<0, num_vec_out, 1>{}([&](auto i) { - constexpr auto idx_y_out_tmp = generate_array( + constexpr auto idx_y_out_tmp = generate_tuple( [&](auto ii) { return ii == y_dim_vec_in ? idx_y_start[ii] + i : idx_y_start[ii]; }, number{}); diff --git a/include/ck_tile/core/tensor/slice_tile.hpp b/include/ck_tile/core/tensor/slice_tile.hpp index 7a4ba2eb79..d51b4c92fb 100644 --- a/include/ck_tile/core/tensor/slice_tile.hpp +++ b/include/ck_tile/core/tensor/slice_tile.hpp @@ -86,7 +86,7 @@ set_slice_tile(static_distributed_tensor, "wrong!"); - dst_tile.SetSlicedThreadData(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); + dst_tile.set_y_sliced_thread_data(sliced_y_origins, sliced_y_lengths, src_tile.get_thread_buffer()); } } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index f397fe32f9..509a5c4a25 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -258,6 +258,51 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP block_sync_lds(); auto v_reg_tensor = load_tile(v_lds_read_window); + +#if 0 + constexpr auto kSeq0 = 64; + // Looped data loading + static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) { + auto k_block_tile = load_tile(k_dram_window); + move_tile_window(k_dram_window, {kSeq0, 0}); + + store_tile(k_lds_write_window, k_block_tile); + + shuffle_distributed_tensor(kt_block_tile, k_block_tile); + store_tile(kt_lds_write_window, kt_block_tile); + + block_sync_lds(); + + auto k_reg_tensor_slice = load_tile(k_lds_read_window); + set_slice_tile(k_reg_tensor, + k_reg_tensor_slice, + Sequence{}, + Sequence<(i_n0+1)*kSeq0, kQKHeaddim>{}); + + auto kt_reg_tensor_slice = load_tile(kt_lds_read_window); + set_slice_tile(kt_reg_tensor, + kt_reg_tensor_slice, + Sequence<0, i_n0*kSeq0>{}, + Sequence{}); + block_sync_lds(); + }); + + static_for<0, kN0 / kSeq0, 1>{}([&](auto i_n0) { + auto v_block_tile = load_tile(v_dram_window); + move_tile_window(v_dram_window, {kSeq0, 0}); + + store_tile(v_lds_write_window, v_block_tile); + + block_sync_lds(); + + auto v_reg_tensor_slice = load_tile(v_lds_read_window); + set_slice_tile(v_reg_tensor, + v_reg_tensor_slice, + Sequence{}, + Sequence<(i_n0+1)*kSeq0, kVHeaddim>{}); + block_sync_lds(); + }); +#endif //---------------------------- Loop Load in ----------------------------// // Q: HBM ->Reg ->LDS auto q_dram_window = @@ -738,12 +783,61 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // STAGE 6, SGrad^T@Q^T Gemm3 const auto ds_gemm = cast_tile(ds); - Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + // Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, + // ds_gemm); + dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); + // if(get_block_1d_id()==0 && get_thread_local_1d_id()<64 &&i_total_loops==0){ + // printf("Tid: %02d, Qt: %04x %04x %04x %04x %04x %04x %04x %04x DsT: %04x %04x + // %04x %04x %04x %04x %04x %04x dk_acc: %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf %.4lf + // %.4lf\n", + // get_thread_local_1d_id(), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<0>{}]))), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<1>{}]))), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<2>{}]))), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<3>{}]))), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<4>{}]))), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<5>{}]))), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<6>{}]))), + // *(reinterpret_cast(&(qt_reg_tensor.get_thread_buffer()[number<7>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<0>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<1>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<2>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<3>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<4>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<5>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<6>{}]))), + // *(reinterpret_cast(&(dst_reg_tensor.get_thread_buffer()[number<7>{}]))), + // dk_acc.get_thread_buffer()[number<0>{}], + // dk_acc.get_thread_buffer()[number<1>{}], + // dk_acc.get_thread_buffer()[number<2>{}], + // dk_acc.get_thread_buffer()[number<3>{}], + // dk_acc.get_thread_buffer()[number<4>{}], + // dk_acc.get_thread_buffer()[number<5>{}], + // dk_acc.get_thread_buffer()[number<6>{}], + // dk_acc.get_thread_buffer()[number<7>{}]); + // } + store_tile(ds_lds_window, ds_gemm); block_sync_lds(); @@ -977,9 +1071,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // STAGE 6, SGrad^T@Q^T Gemm3 const auto ds_gemm = cast_tile(ds); - Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + // Policy::template SGradTFromGemm2CToGemm3A(dst_reg_tensor, ds_gemm); + dst_reg_tensor.get_thread_buffer() = ds_gemm.get_thread_buffer(); gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor); store_tile(ds_lds_window, ds_gemm); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 207c74cb56..54dc8186e4 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -371,6 +371,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() { +#if 0 constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; @@ -389,11 +390,34 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence<1, 0>>, sequence<1, 2>, sequence<2, 1>>{}); +#elif 1 + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t kMWarps = 2; + constexpr index_t kKWarps = 2; + constexpr index_t kKRow = 2; + constexpr index_t kMRow = 2; + constexpr index_t kRowsize = 16; + constexpr index_t K1 = 2; + constexpr index_t kMPair = 2; + constexpr index_t kMRepeat = 2; + constexpr index_t kMGroup = kNPerBlock/16; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 3, 2>>, + sequence<1, 1, 1, 2>, + sequence<0, 2, 4, 3>>{}); +#endif } template CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() { +#if 0 constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; @@ -412,52 +436,72 @@ struct BlockFmhaBwdPipelineDefaultPolicy tuple, sequence<2, 0>>, sequence<1, 2>, sequence<0, 1>>{}); +#elif 1 + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; + + constexpr index_t kMWarps = 2; + constexpr index_t kKWarps = 2; + constexpr index_t kKRow = 2; + constexpr index_t kMRow = 2; + constexpr index_t kRowsize = 16; + constexpr index_t K1 = 2; + constexpr index_t kMPair = 2; + constexpr index_t kMRepeat = 2; + constexpr index_t kMGroup = kNPerBlock/16; + + return make_static_tile_distribution( + tile_distribution_encoding, + tuple, + sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 3, 2>>, + sequence<1, 1, 1, 2>, + sequence<0, 2, 4, 3>>{}); +#endif } template CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t K1 = GetAlignmentQ(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M1 = get_warp_size() / K0; - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M2 = kMPerBlock / (M1 * M0); + constexpr index_t kMWarps = 2; + constexpr index_t kKWarps = 2; + constexpr index_t kKRow = 2; + constexpr index_t kMRow = 2; + constexpr index_t kRowsize = 16; + constexpr index_t K1 = 2; + constexpr index_t kMPair = 2; + constexpr index_t kMRepeat = 2; return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); + tuple, + sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 1, 2>, + sequence<1, 3, 3>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - - constexpr index_t K1 = GetAlignmentOGrad(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t M1 = get_warp_size() / K0; - constexpr index_t M0 = kBlockSize / get_warp_size(); - constexpr index_t M2 = kMPerBlock / (M1 * M0); + constexpr index_t kMWarps = 2; + constexpr index_t kKWarps = 2; + constexpr index_t kKRow = 2; + constexpr index_t kMRow = 2; + constexpr index_t kRowsize = 16; + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t kMPair = 2; + constexpr index_t kMRepeat = 2; return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<1, 2>, - sequence<2, 1>>{}); + tuple, + sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 1, 2>, + sequence<1, 3, 3>>{}); } template @@ -666,6 +710,145 @@ struct BlockFmhaBwdPipelineDefaultPolicy return 16 / sizeof(GemmDataType); } + CK_TILE_HOST_DEVICE static constexpr auto Make16x128LdsBlockDescriptor() + { + constexpr index_t MWarp = 2; + constexpr index_t KWarp = 2; + constexpr index_t KRow = 2; + constexpr index_t KBit0 = 2; + constexpr index_t KBit1 = 2; + constexpr index_t KBit2 = 2; + constexpr index_t KBit3 = 2; + constexpr index_t KBit4 = 2; + constexpr index_t K1 = 2; + constexpr index_t MPair = 2; + constexpr index_t MRepeat = 2; + + // K:HeadDim, M:Seq, 11 Dimensions Total + // W T I V + // Total: 4*64*4*2 = 2^11 + + // W I T W I T T T T T V + // 2 2 2 2 2 2 2 2 2 2 2 + // KWarp, MPair, KRow, MWarp, MRepeat, KBit<1, 2, 4, 3, 0>, K1 + // M = 2^4 + // K = 2^7 + + constexpr auto lds_16x128_block_desc_raw = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple( + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto lds_16x128_block_desc = transform_tensor_descriptor( + lds_16x128_block_desc_raw, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}, number{})), + make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}))), + make_tuple(sequence<3, 4, 1>{}, sequence<0, 2, 7, 8, 6, 5, 9, 10>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + + return lds_16x128_block_desc; + } + + CK_TILE_HOST_DEVICE static constexpr auto Make16x128TransLdsBlockDescriptor() + { + constexpr index_t MWarp = 2; + constexpr index_t KWarp = 2; + constexpr index_t KRow = 2; + constexpr index_t MRow = 2; + constexpr index_t KGroup = 2; + constexpr index_t KBit0 = 2; + constexpr index_t KBit1 = 2; + constexpr index_t KBit2 = 2; + constexpr index_t K1 = 2; + constexpr index_t MPair = 2; + constexpr index_t MRepeat = 2; + + // K:HeadDim, M:Seq, 11 Dimensions Total + // W T I V + // Total: 4*64*4*2 = 2^11 + // W I T W I T T T T T V + // 2 2 2 2 2 2 2 2 2 2 2 + // Kwarp, K1, KRow, MWarp, MRepeat, , KGroup, MRow, MPair + // M = 2^4 + // K = 2^7 + + constexpr auto lds_16x128_trans_block_desc_raw = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}), + make_tuple(number{}, + // Padding + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + constexpr auto lds_16x128_trans_block_desc = transform_tensor_descriptor( + lds_16x128_trans_block_desc_raw, + make_tuple(make_merge_transform_v3_division_mod(make_tuple(number{}, + number{}, + number{}, + number{}, + number{}, + number{}, + number{})), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number{}, number{}, number{}))), + make_tuple(sequence<0, 2, 8, 6, 5, 7, 1>{}, sequence<3, 4, 9, 10>{}), + make_tuple(sequence<1>{}, sequence<0>{})); + + return lds_16x128_trans_block_desc; + } + template CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() { @@ -1008,12 +1191,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { +#if 0 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPack = GetSmemKPackQ(); return MakeXLdsBlockDescriptor(); +#elif 1 + return Make16x128LdsBlockDescriptor(); +#endif } template @@ -1051,36 +1238,30 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - - constexpr index_t K1 = GetAlignmentQ(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentQ(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t kMWarps = 2; + constexpr index_t kKWarps = 2; + constexpr index_t kKRow = 2; + constexpr index_t kMRow = 2; + constexpr index_t kRowsize = 16; + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t kMPair = 2; + constexpr index_t kMRepeat = 2; return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + tuple, + sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 2, 1>, + sequence<1, 3, 3>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor() { - // Hold full block data - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; - - constexpr index_t kKPack = GetAlignmentQ(); - constexpr index_t kKPackT = GetSmemKPackQT(); - - return MakeXTLdsBlockDescriptor(); + return Make16x128TransLdsBlockDescriptor(); } template @@ -1215,12 +1396,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() { // Hold full block data +#if 0 constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPack = GetSmemKPackOGrad(); return MakeXLdsBlockDescriptor(); +#elif 1 + return Make16x128LdsBlockDescriptor(); +#endif } template @@ -1258,36 +1443,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor() { - constexpr index_t kBlockSize = Problem::kBlockSize; - - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - - constexpr index_t K1 = GetAlignmentOGrad(); - constexpr index_t K0 = kKPerBlock / K1; - constexpr index_t N2 = GetTransposedAlignmentOGrad(); - constexpr index_t N1 = get_warp_size() / K0; - constexpr index_t N0 = kBlockSize / get_warp_size(); + constexpr index_t kMWarps = 2; + constexpr index_t kKWarps = 2; + constexpr index_t kKRow = 2; + constexpr index_t kMRow = 2; + constexpr index_t kRowsize = 16; + constexpr index_t K1 = GetAlignmentQ(); + constexpr index_t kMPair = 2; + constexpr index_t kMRepeat = 2; return make_static_tile_distribution( tile_distribution_encoding, - tuple, sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<1, 0>>, - sequence<2, 1>, - sequence<1, 2>>{}); + tuple, + sequence>, + tuple, sequence<2, 1, 2>>, + tuple, sequence<1, 2, 2>>, + sequence<1, 2, 1>, + sequence<1, 3, 3>>{}); } template CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor() { - // Hold all data - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; - - constexpr index_t kKPack = GetAlignmentOGrad(); - constexpr index_t kKPackT = GetSmemKPackOGradT(); - - return MakeXTLdsBlockDescriptor(); + return Make16x128TransLdsBlockDescriptor(); } template @@ -1916,18 +2094,26 @@ struct BlockFmhaBwdPipelineDefaultPolicy static constexpr index_t D_VMEM_READ = 1; // LDS Read - // 16 * 128 / 64 / 4 = 8 static constexpr index_t OGradT_LDS_READ = - kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); + // 16 * 128 / 64 / 4 = 8 + // kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); + // 16 * 128 / 64 / 8 = 4 + kM0 * kVHeaddim / get_warp_size() / 8; // 16 * 128 / 64 / 4 = 8 static constexpr index_t QT_LDS_READ = - kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); + // 16 * 128 / 64 / 4 = 8 + // kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); + // 16 * 128 / 64 / 8 = 4 + kM0 * kVHeaddim / get_warp_size() / 8; // 16 * 32 / 64 / 8 = 1 static constexpr index_t SGradT_LDS_READ_P1 = kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); // 16 * 128 / 64 / 8 = 4 static constexpr index_t Q_LDS_READ = - kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ(); + // 16 * 128 / 64 / 8 = 4 + // kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetSmemKPackQ(); + // 16 * 128 / 64 / 8 = 4 + kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / 8; // 1 static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); // 16 * 96 / 64 / 8 = 3 @@ -1935,7 +2121,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); // 16 * 128 / 64 / 8 = 4 static constexpr index_t OGrad_LDS_READ = - kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad(); + // kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetSmemKPackOGrad(); + kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / 8; // 1 static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp index 1fd12973f6..e989d97188 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm.hpp @@ -64,6 +64,10 @@ using WarpGemmMfmaF16F16F32M64N4K16 = WarpGemmImpl, 4>>; +using WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution = + WarpGemmImpl, + 2>>; // bf16 using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl< @@ -120,6 +124,10 @@ using WarpGemmMfmaBf16Bf16F32M64N4K16 = WarpGemmImpl, 4>>; +using WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution = + WarpGemmImpl, + 2>>; // fp8 using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl< diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp index e7d4c37966..685950d6d4 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma.hpp @@ -642,6 +642,219 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution } }; +template +struct WarpGemmAtrributeMfmaIterateNAndTransposedCDistribution +{ + using Impl = remove_cvref_t; + + // swap A and B + using ADataType = typename Impl::BDataType; + using BDataType = typename Impl::ADataType; + using CDataType = typename Impl::CDataType; + + using AVecType = typename Impl::AVecType; + + using BVecType = + ext_vector_t::vector_size * kNIter>; + using CVecType = + ext_vector_t::vector_size * kNIter>; + + static constexpr index_t kM = Impl::kN; + static constexpr index_t kN = Impl::kM * kNIter; + static constexpr index_t kK = Impl::kK; + static constexpr index_t kKPerThread = Impl::kABKPerLane; + + CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kNIter; } + + static_assert(Impl::kAMBlock == 1 || Impl::kBNBlock == 1, + "Multi-block on both M & N directions is not supported"); + + CK_TILE_DEVICE static constexpr auto get_awarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // single block to multi-block thread mapping + return tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // each N blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<2>, + sequence<1>>{}; + } + } + + CK_TILE_DEVICE static constexpr auto get_bwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + // each M blocks share the same data + return tile_distribution_encoding< + sequence, + tuple, sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + // single block to multi-block thread mapping + return tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<1, 2>, + sequence<0, 1>>{}; + } + } + + CK_TILE_DEVICE static constexpr auto get_cwarp_dstr_encoding() + { + if constexpr(Impl::kAMBlock == 1 && Impl::kBNBlock == 1) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2, 2>, + sequence<0, 1, 3>>{}; + } + else if constexpr(Impl::kAMBlock == 1 && 1 < Impl::kBNBlock) + { + return tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2, 2>, + sequence<0, 1, 3>>{}; + } + else if constexpr(1 < Impl::kAMBlock && Impl::kBNBlock == 1) + { + return tile_distribution_encoding, + tuple, + sequence>, + tuple>, + tuple>, + sequence<2, 2, 2>, + sequence<0, 1, 3>>{}; + } + } + + using AWarpDstrEncoding = decltype(get_awarp_dstr_encoding()); + + using BWarpDstrEncoding = decltype(get_bwarp_dstr_encoding()); + + using CWarpDstrEncoding = decltype(get_cwarp_dstr_encoding()); + + template + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + bool_constant = {}) const + { + using buf_c = thread_buffer; + using buf_b = thread_buffer; + // swap A and B, value and type + // Bug: result not write back. + static_for<0, kNIter, 1>{}([&](auto iNIter) { + Impl{}( + reinterpret_cast(c_vec).template get_as()(iNIter), + reinterpret_cast(b_vec) + .template get_as()[iNIter], + a_vec, + bool_constant{}); + }); + + // if(get_thread_global_1d_id()==0){ + // printf("Enter here, ") + // } + } + + template + // c_vec += a_vec * b_vec + CK_TILE_DEVICE void operator()(CVecType& c_vec, + const AVecType& a_vec, + const BVecType& b_vec, + number, + bool_constant = {}) const + { + using buf_c = thread_buffer; + using buf_b = thread_buffer; + + static_assert(iNIter < kNIter); + // swap A and B, value and type + // static_for<0, kNIter, 1>{}([&](auto iNIter) { + Impl{}(reinterpret_cast(c_vec).template get_as()(iNIter), + reinterpret_cast(b_vec) + .template get_as()[iNIter], + a_vec, + bool_constant{}); + //}); + } + + // c_vec = a_vec * b_vec + CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const + { + constexpr auto I0 = number<0>{}; + + using buf_c = thread_buffer; + buf_c c_buf; + using buf_b = thread_buffer; + + static_for<0, kNIter, 1>{}([&](auto iNIter) { + auto c_vec_tmp = Impl{}( + reinterpret_cast(b_vec).template get_as()( + iNIter), + a_vec); + + c_buf.template set_as(iNIter, c_vec_tmp); + }); + + return c_buf.template get_as()[I0]; + } +}; + template struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB { diff --git a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp index 9c319b5e5f..299e1fcd4b 100644 --- a/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp +++ b/include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp @@ -34,6 +34,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaF16F16F32M16N32K16TransposedCDistribution; }; // bf16 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; @@ -49,6 +50,7 @@ template<> struct WarpGemmMfmaDispatcher struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; }; template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; }; +template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfmaBf16Bf16F32M16N32K16TransposedCDistribution; }; // fp8 template<> struct WarpGemmMfmaDispatcher { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };