diff --git a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp index af39d1271e..f257e3d73e 100644 --- a/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp +++ b/example/ck_tile/01_fmha/example_bwd_fmha_bf16.cpp @@ -17,41 +17,33 @@ #include #include - // Convert DQ using fmha_dtype_0 = FmhaBwdFp16; -using fmha_bwd_convert_dq_trait_0 = - ck_tile::TileFmhaBwdConvertQGradTraits; +using fmha_bwd_convert_dq_trait_0 = ck_tile::TileFmhaBwdConvertQGradTraits; -using fmha_bwd_convert_dq_pipeline_problem_0 = - ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< - typename FmhaBwdTypeConfig::AccDataType, - typename FmhaBwdTypeConfig::QGradDataType, - /* BlockSize = */ 256, - 64, - 128, - 128, - false, - false, - fmha_bwd_convert_dq_trait_0>; +using fmha_bwd_convert_dq_pipeline_problem_0 = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem< + typename FmhaBwdTypeConfig::AccDataType, + typename FmhaBwdTypeConfig::QGradDataType, + /* BlockSize = */ 256, + 64, + 128, + 128, + false, + false, + fmha_bwd_convert_dq_trait_0>; using fmha_bwd_convert_dq_0 = typename ck_tile::BlockFmhaBwdConvertQGrad; -using fmha_bwd_convert_dq_kernel_0 = - ck_tile::FmhaBwdConvertQGradKernel; +using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel; -using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128, - FmhaBwdFp16, - false, - false, - false, - false>; +using convert_dq_trait_0 = + fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; template <> void fmha_bwd_convert_dq_oneshot_(const ck_tile::stream_config& s, - fmha_bwd_args a) + fmha_bwd_args a) { using k_ = fmha_bwd_convert_dq_kernel_0; auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids(a); @@ -69,8 +61,7 @@ std::string fmha_bwd_convert_dq_get_name_() } // dq_dk_dv -using fmha_block_tile_0 = ck_tile:: - sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; +using fmha_block_tile_0 = ck_tile::sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>; using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>; using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>; using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>; @@ -82,29 +73,29 @@ using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>; // G1&G3 -> GdKV // G4 -> GdQ using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape; + fmha_block_warps0_0, + fmha_warp_tile0_0, + fmha_block_warps1_0, + fmha_warp_tile1_0, + fmha_block_warps0_0, + fmha_warp_tile0_0, + fmha_block_warps1_0, + fmha_warp_tile1_0, + fmha_block_warps2_0, + fmha_warp_tile0_0>; using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits; + false, + false, + false, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + 1>; using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask; -using fmha_dropout_0 = ck_tile::BlockDropoutBwd; +using fmha_dropout_0 = ck_tile::BlockDropoutBwd; using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< typename FmhaBwdTypeConfig::QDataType, @@ -129,7 +120,8 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem< fmha_dropout_0, fmha_bwd_trait_0>; -using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; +using fmha_bwd_pipeline_0 = + ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP; using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogueProblem::AccDataType, @@ -143,28 +135,25 @@ using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue< false, false>>; -using fmha_bwd_dq_dk_dv_kernel_0 = - ck_tile::FmhaBwdDQDKDVKernel; +using fmha_bwd_dq_dk_dv_kernel_0 = ck_tile:: + FmhaBwdDQDKDVKernel; using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128, - FmhaBwdFp16, - false, - ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, - fmha_mask_0, - fmha_dropout_0, - ck_tile::BlockAttentionBiasEnum::NO_BIAS, - false, - false, - false, - false, - false, - false>; + FmhaBwdFp16, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + fmha_mask_0, + fmha_dropout_0, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; template <> -void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, - fmha_bwd_args a) +void fmha_bwd_dq_dk_dv_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) { using k_ = fmha_bwd_dq_dk_dv_kernel_0; auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids(a); @@ -182,8 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_() } // dot_do_o -using fmha_bwd_dot_do_o_trait_0 = - ck_tile::TileFmhaBwdOGradDotOTraits; +using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits; using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem< typename FmhaBwdTypeConfig::ODataType, @@ -197,11 +185,9 @@ using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipel using fmha_bwd_dot_do_o_0 = typename ck_tile::BlockFmhaBwdOGradDotO; -using fmha_bwd_dot_do_o_kernel_0 = - ck_tile::FmhaBwdOGradDotOKernel; +using fmha_bwd_dot_do_o_kernel_0 = ck_tile::FmhaBwdOGradDotOKernel; -using dot_do_o_trait_0 = - fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; +using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; template <> void fmha_bwd_dot_do_o_oneshot_(const ck_tile::stream_config& s, fmha_bwd_args a) @@ -221,7 +207,6 @@ std::string fmha_bwd_dot_do_o_get_name_() return k_::GetName(); } - template std::ostream& operator<<(std::ostream& os, const std::vector& v) { @@ -244,25 +229,53 @@ template 0) - std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " << fmha_bwd_dq_dk_dv_get_name_() << ", " << fmha_bwd_convert_dq_get_name_() << std::flush; - return ck_tile::launch_kernel(s, - [=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_(s_, a); }, - [=](const ck_tile::stream_config& s_){ fmha_bwd_dq_dk_dv_oneshot_(s_, a); }, - [=](const ck_tile::stream_config& s_){ fmha_bwd_convert_dq_oneshot_(s_, a); } - ); + std::cout << ", " << fmha_bwd_dot_do_o_get_name_() << ", " + << fmha_bwd_dq_dk_dv_get_name_() << ", " + << fmha_bwd_convert_dq_get_name_() << std::flush; + return ck_tile::launch_kernel( + s, + [=](const ck_tile::stream_config& s_) { + fmha_bwd_dot_do_o_oneshot_(s_, a); + }, + [=](const ck_tile::stream_config& s_) { + fmha_bwd_dq_dk_dv_oneshot_(s_, a); + }, + [=](const ck_tile::stream_config& s_) { + fmha_bwd_convert_dq_oneshot_(s_, a); + }); } -float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){ +float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s) +{ float r = -1; - if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) && - (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) { + if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) && + (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && + (t.has_dbias == false) && (t.has_dropout == false) && + (a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && + (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) + { using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; - using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; + using dq_dk_dv_trait_ = + fmha_bwd_dq_dk_dv_traits_<128, + FmhaBwdFp16, + false, + ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + ck_tile::SimplifiedGenericAttentionMask, + ck_tile::BlockDropoutBwd, + ck_tile::BlockAttentionBiasEnum::NO_BIAS, + false, + false, + false, + false, + false, + false>; + using convert_dq_trait_ = + fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; r = fmha_bwd_(s, a); return r; } - else{ + else + { assert("unsupported case\n"); return r; } @@ -806,11 +819,13 @@ bool run(const ck_tile::ArgParser& arg_parser) float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config); // using instance: - // using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; - // using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; - // using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>; - // r = fmha_bwd_(s, a); - // return r; + // using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>; + // using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, + // ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, + // ck_tile::SimplifiedGenericAttentionMask, ck_tile::BlockDropoutBwd, + // ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; using + // convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, + // false>; r = fmha_bwd_(s, a); return r; if(ave_time < 0) { std::cout << ", not supported yet" << std::flush << std::endl; diff --git a/example/ck_tile/03_gemm/run_gemm_example.inc b/example/ck_tile/03_gemm/run_gemm_example.inc index 028f8a44c3..8af83532b3 100644 --- a/example/ck_tile/03_gemm/run_gemm_example.inc +++ b/example/ck_tile/03_gemm/run_gemm_example.inc @@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K, return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); } -template +template float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, ck_tile::DeviceMem& b_k_n_dev_buf, ck_tile::DeviceMem& c_m_n_dev_buf, @@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, args.stride_B = stride_B; args.stride_C = stride_C; - float ave_time = gemm_calc( - args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); + float ave_time = + gemm_calc( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat}); std::size_t flop = std::size_t(2) * M * N * K; std::size_t num_byte = @@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K << " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C - << " A_Layout =" << ALayout::name - << " B_Layout =" << BLayout::name - << " C_Layout =" << CLayout::name - << " A Type = " << DataTypeTraits::name - << " B Type = " << DataTypeTraits::name - << " C Type = " << DataTypeTraits::name - << " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " - << std::endl; + << " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits::name + << " B Type = " << DataTypeTraits::name + << " C Type = " << DataTypeTraits::name << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; return ave_time; } @@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc, if(!result) return -1; - using ADataType = typename GemmBasicTypeConfig::ADataType; - using BDataType = typename GemmBasicTypeConfig::BDataType; - using CDataType = typename GemmBasicTypeConfig::CDataType; - using AccDataType = typename GemmBasicTypeConfig::AccDataType; + using ADataType = typename GemmBasicTypeConfig::ADataType; + using BDataType = typename GemmBasicTypeConfig::BDataType; + using CDataType = typename GemmBasicTypeConfig::CDataType; + using AccDataType = typename GemmBasicTypeConfig::AccDataType; ck_tile::index_t M = arg_parser.get_int("m"); ck_tile::index_t N = arg_parser.get_int("n"); @@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc, c_m_n_dev_buf.SetZero(); c_m_n_dev_result.SetZero(); - invoke_gemm(a_m_k_dev_buf, - b_k_n_dev_buf, - c_m_n_dev_buf, - M, - N, - K, - stride_A, - stride_B, - stride_C, - kbatch, - n_warmup, - n_repeat); + invoke_gemm( + a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat); c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); bool pass = true; @@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc, a_m_k, b_k_n, c_m_n_host_ref); const float max_accumulated_value = *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol - (K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_host_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), @@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc, c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); const float max_accumulated_value = *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); - const auto rtol_atol = calculate_rtol_atol - (K, kbatch, max_accumulated_value); - pass = ck_tile::check_err(c_m_n_dev_result, + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, c_m_n_gpu_ref, "Error: Incorrect results!", rtol_atol.at(ck_tile::number<0>{}), diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 6220a59ea1..f397fe32f9 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -137,6 +137,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], "wrong!"); + // if (threadIdx.x == 0){ + // HotLoopScheduler::print(); + // } // Block GEMM constexpr auto gemm_0 = Policy::template GetQKBlockGemm(); constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm(); @@ -532,7 +535,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // Hot loop while(i_total_loops < (num_total_loop - 1)) { - // STAGE 1, Q@K Gemm0 + // STAGE 1, Q@K Gemm0 d_block_tile = load_tile(d_dram_window); move_tile_window(d_dram_window, {kM0}); @@ -664,7 +667,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP // decltype(p_gemm)>(pt_reg_tensor, p_gemm); pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer(); - auto qt_reg_tensor = load_tile(qt_lds_read_window); + auto qt_reg_tensor = load_tile(qt_lds_read_window); + gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor); HotLoopScheduler::template GemmStagedScheduler<1>(); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 0f21f77992..7b7028af10 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -202,9 +202,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) - ? kMaxVecLoad - : kMinVecLoad; + constexpr index_t kVecLoad = + ((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad; return kVecLoad; } @@ -260,9 +259,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; - constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad) - ? kMaxVecLoad - : kMinVecLoad; + constexpr index_t kVecLoad = + ((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad; return kVecLoad; } @@ -607,7 +605,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() { - return GetAlignmentQ(); + using QDataType = remove_cvref_t; + return 16 / sizeof(QDataType); } template @@ -649,7 +648,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() { - return GetAlignmentOGrad(); + using OGradDataType = remove_cvref_t; + return 16 / sizeof(OGradDataType); } template @@ -666,48 +666,73 @@ struct BlockFmhaBwdPipelineDefaultPolicy return 16 / sizeof(GemmDataType); } - template + template CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor() { - constexpr auto DataTypeSize = 2; // sizeof(F16/BF16) - constexpr auto MNLdsLayer = - (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); + if constexpr(XorLdsLayout) + { + constexpr auto DataTypeSize = 2; // sizeof(F16/BF16) + constexpr auto MNLdsLayer = + (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); - constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( - make_tuple(number{}, - number{}, - number{}), - make_tuple(number{}, number{}, number<1>{}), - number{}, - number<1>{}); + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number{}), + make_tuple(number{}, number{}, number<1>{}), + number{}, + number<1>{}); - constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor( - x_lds_block_desc_0, - make_tuple(make_xor_transform(make_tuple(number{}, - number{})), - make_pass_through_transform(number{})), - make_tuple(sequence<1, 0>{}, sequence<2>{}), - make_tuple(sequence<1, 0>{}, sequence<2>{})); + constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_xor_transform(make_tuple(number{}, + number{})), + make_pass_through_transform(number{})), + make_tuple(sequence<1, 0>{}, sequence<2>{}), + make_tuple(sequence<1, 0>{}, sequence<2>{})); - constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( - x_lds_block_desc_permuted, - make_tuple(make_unmerge_transform( - make_tuple(number{}, number{})), - make_pass_through_transform(number{}), - make_pass_through_transform(number{})), - make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), - make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); + constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( + x_lds_block_desc_permuted, + make_tuple(make_unmerge_transform( + make_tuple(number{}, number{})), + make_pass_through_transform(number{}), + make_pass_through_transform(number{})), + make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), + make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); - constexpr auto x_lds_block_desc = transform_tensor_descriptor( - x_lds_block_desc_xk0_mnldslayer_mn_xk1, - make_tuple(make_merge_transform_v3_division_mod( - make_tuple(number{}, number{})), - make_merge_transform_v3_division_mod( - make_tuple(number{}, number{}))), - make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), - make_tuple(sequence<0>{}, sequence<1>{})); + constexpr auto x_lds_block_desc = transform_tensor_descriptor( + x_lds_block_desc_xk0_mnldslayer_mn_xk1, + make_tuple(make_merge_transform_v3_division_mod( + make_tuple(number{}, number{})), + make_merge_transform_v3_division_mod( + make_tuple(number{}, number{}))), + make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); - return x_lds_block_desc; + return x_lds_block_desc; + } + else + { + constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( + make_tuple(number{}, + number{}, + number<64 / KPack>{}, + number{}), + make_tuple(number{}, + number<(64 / KPack + 1) * KPack>{}, + number{}, + number<1>{}), + number{}, + number<1>{}); + + return transform_tensor_descriptor( + x_lds_block_desc_0, + make_tuple(make_pass_through_transform(number{}), + make_merge_transform_v3_division_mod(make_tuple( + number{}, number<64 / KPack>{}, number{}))), + make_tuple(sequence<0>{}, sequence<1, 2, 3>{}), + make_tuple(sequence<0>{}, sequence<1>{})); + } } template (); + constexpr index_t KPack = GetSmemKPackQ(); - return MakeXLdsBlockDescriptor(); + return MakeXLdsBlockDescriptor(); } template @@ -1193,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; - constexpr index_t kKPack = GetSmemKPackOGrad(); + constexpr index_t KPack = GetSmemKPackOGrad(); - return MakeXLdsBlockDescriptor(); + return MakeXLdsBlockDescriptor(); } template @@ -1681,14 +1706,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST; constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST; // To hide instruction issue latency - constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); + constexpr index_t LDS_READ_PER_MFMA = + ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); static_for<0, VMEM_READ_INST, 1>{}([&](auto i) { __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr (i * MFMA_PER_VMEM_READ + j{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr (i {}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr (i < LDS_WRITE_INST){ + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i < LDS_WRITE_INST) + { __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write } }); @@ -1749,31 +1781,43 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MFMA_INST = Gemm3MFMA; // To hide instruction issue latency - constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST); + constexpr index_t LDS_WRITE_PER_MFMA = + ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST); constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA; - constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE)); + constexpr index_t LDS_READ_PER_MFMA = + ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE)); static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){ - if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){ - __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST) + { + if constexpr((i + 1) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST) + { + __builtin_amdgcn_sched_group_barrier( + 0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write } - else{ - __builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write + else + { + __builtin_amdgcn_sched_group_barrier( + 0x200, LDS_WRITE_PER_MFMA, 0); // DS Write } } }); static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ - if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST) + { + if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read } - else{ - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, LDS_READ_PER_MFMA, 0); // DS Read } } }); @@ -1788,21 +1832,42 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MFMA_INST = Gemm4MFMA; // To hide instruction issue latency - constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); + constexpr index_t LDS_READ_PER_MFMA = + ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST); static_for<0, MFMA_INST, 1>{}([&](auto i) { - __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){ - if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){ - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read + __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA + if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST) + { + if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST) + { + __builtin_amdgcn_sched_group_barrier( + 0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read } - else{ - __builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read + else + { + __builtin_amdgcn_sched_group_barrier( + 0x100, LDS_READ_PER_MFMA, 0); // DS Read } } }); } + CK_TILE_HOST_DEVICE static void print() + { + printf("LDS instruction{"); + // + printf("OGradT_LDS_READ: %d, ", OGradT_LDS_READ); + printf("OGrad_LDS_READ: %d, ", OGrad_LDS_READ); + printf("QT_LDS_READ: %d, ", QT_LDS_READ); + printf("Q_LDS_READ: %d, ", Q_LDS_READ); + printf("SGradT_LDS_READ_P1: %d, ", SGradT_LDS_READ_P1); + printf("SGradT_LDS_READ_P2: %d, ", SGradT_LDS_READ_P2); + printf("LSE_LDS_READ: %d, ", LSE_LDS_READ); + printf("D_LDS_READ: %d, ", D_LDS_READ); + printf("}"); + } + private: static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; @@ -1818,6 +1883,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy static constexpr index_t WarpGemmN = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}); static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8; + static constexpr index_t Gemm0MWarp = + Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{}); + static constexpr index_t Gemm2MWarp = + Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{}); static constexpr index_t Gemm4MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); static constexpr index_t Gemm4NWarp = @@ -1847,20 +1916,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy static constexpr index_t D_VMEM_READ = 1; // LDS Read + // 16 * 128 / 64 / 4 = 8 static constexpr index_t OGradT_LDS_READ = kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); + // 16 * 128 / 64 / 4 = 8 static constexpr index_t QT_LDS_READ = kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); + // 16 * 32 / 64 / 8 = 1 static constexpr index_t SGradT_LDS_READ_P1 = // kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2; - static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ(); + // 16 * 128 / 64 / 8 = 4 + static constexpr index_t Q_LDS_READ = + kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetAlignmentQ(); + // 1 static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); + // 16 * 96 / 64 / 8 = 3 static constexpr index_t SGradT_LDS_READ_P2 = // kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2; + // 16 * 128 / 64 / 8 = 4 static constexpr index_t OGrad_LDS_READ = - kM0 * kK2 / kBlockSize / GetAlignmentOGrad(); + kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetAlignmentOGrad(); + // 1 static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); // LDS Write