mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-02 21:27:45 +00:00
change q, do lds layout
This commit is contained in:
@@ -17,41 +17,33 @@
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
|
||||
// Convert DQ
|
||||
using fmha_dtype_0 = FmhaBwdFp16;
|
||||
|
||||
using fmha_bwd_convert_dq_trait_0 =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
|
||||
using fmha_bwd_convert_dq_trait_0 = ck_tile::TileFmhaBwdConvertQGradTraits<false, false, 2>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_0 =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
64,
|
||||
128,
|
||||
128,
|
||||
false,
|
||||
false,
|
||||
fmha_bwd_convert_dq_trait_0>;
|
||||
using fmha_bwd_convert_dq_pipeline_problem_0 = ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
64,
|
||||
128,
|
||||
128,
|
||||
false,
|
||||
false,
|
||||
fmha_bwd_convert_dq_trait_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_0 =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_0 =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
using fmha_bwd_convert_dq_kernel_0 = ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_0>;
|
||||
|
||||
using convert_dq_trait_0 = fmha_bwd_convert_dq_traits_<128,
|
||||
FmhaBwdFp16,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
using convert_dq_trait_0 =
|
||||
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
@@ -69,8 +61,7 @@ std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_0>()
|
||||
}
|
||||
|
||||
// dq_dk_dv
|
||||
using fmha_block_tile_0 = ck_tile::
|
||||
sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_tile_0 = ck_tile::sequence<16, 128, 128, 16, 128, 16, 32, 128, 128>;
|
||||
using fmha_block_warps0_0 = ck_tile::sequence<1, 4, 1>;
|
||||
using fmha_block_warps1_0 = ck_tile::sequence<4, 1, 1>;
|
||||
using fmha_block_warps2_0 = ck_tile::sequence<1, 4, 1>;
|
||||
@@ -82,29 +73,29 @@ using fmha_warp_tile1_0 = ck_tile::sequence<16, 16, 16>;
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_0 = ck_tile::TileFmhaBwdShape<fmha_block_tile_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps0_0,
|
||||
fmha_warp_tile0_0,
|
||||
fmha_block_warps1_0,
|
||||
fmha_warp_tile1_0,
|
||||
fmha_block_warps2_0,
|
||||
fmha_warp_tile0_0>;
|
||||
|
||||
using fmha_bwd_trait_0 = ck_tile::TileFmhaTraits<false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
1>;
|
||||
using fmha_mask_0 = ck_tile::SimplifiedGenericAttentionMask<false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
using fmha_dropout_0 = ck_tile::BlockDropoutBwd<false, true, false>;
|
||||
|
||||
using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::QDataType,
|
||||
@@ -129,7 +120,8 @@ using fmha_bwd_pipeline_problem_0 = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
fmha_dropout_0,
|
||||
fmha_bwd_trait_0>;
|
||||
|
||||
using fmha_bwd_pipeline_0 = ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
using fmha_bwd_pipeline_0 =
|
||||
ck_tile::BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP<fmha_bwd_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<FmhaBwdFp16>::AccDataType,
|
||||
@@ -143,28 +135,25 @@ using fmha_bwd_dv_epilogue_0 = ck_tile::Default2DEpilogue<
|
||||
false,
|
||||
false>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0,
|
||||
fmha_bwd_dk_epilogue_0,
|
||||
fmha_bwd_dv_epilogue_0>;
|
||||
using fmha_bwd_dq_dk_dv_kernel_0 = ck_tile::
|
||||
FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_0, fmha_bwd_dk_epilogue_0, fmha_bwd_dv_epilogue_0>;
|
||||
|
||||
using dq_dk_dv_trait_0 = fmha_bwd_dq_dk_dv_traits_<128,
|
||||
FmhaBwdFp16,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
FmhaBwdFp16,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
fmha_mask_0,
|
||||
fmha_dropout_0,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_0;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
@@ -182,8 +171,7 @@ std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_0>()
|
||||
}
|
||||
|
||||
// dot_do_o
|
||||
using fmha_bwd_dot_do_o_trait_0 =
|
||||
ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
|
||||
using fmha_bwd_dot_do_o_trait_0 = ck_tile::TileFmhaBwdOGradDotOTraits<false, false, 2>;
|
||||
|
||||
using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_0>::ODataType,
|
||||
@@ -197,11 +185,9 @@ using fmha_bwd_dot_do_o_pipeline_problem_0 = ck_tile::BlockFmhaBwdOGradDotOPipel
|
||||
using fmha_bwd_dot_do_o_0 =
|
||||
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_0>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_0 =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
|
||||
using fmha_bwd_dot_do_o_kernel_0 = ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_0>;
|
||||
|
||||
using dot_do_o_trait_0 =
|
||||
fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
using dot_do_o_trait_0 = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_0>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
@@ -221,7 +207,6 @@ std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_0>()
|
||||
return k_::GetName();
|
||||
}
|
||||
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
@@ -244,25 +229,53 @@ template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_d
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", " << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", " << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_){ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); },
|
||||
[=](const ck_tile::stream_config& s_){ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }
|
||||
);
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << ", "
|
||||
<< fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << ", "
|
||||
<< fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(
|
||||
s,
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a);
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a);
|
||||
},
|
||||
[=](const ck_tile::stream_config& s_) {
|
||||
fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a);
|
||||
});
|
||||
}
|
||||
|
||||
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){
|
||||
float fmha_bwd(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s)
|
||||
{
|
||||
float r = -1;
|
||||
if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) && (t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) && (t.has_dbias == false) && (t.has_dropout == false) &&
|
||||
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) && (a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false)) {
|
||||
if(t.data_type.compare("fp16") == 0 && (t.is_group_mode == false) &&
|
||||
(t.mask_type == mask_enum::no_mask) && (t.bias_type == bias_enum::no_bias) &&
|
||||
(t.has_dbias == false) && (t.has_dropout == false) &&
|
||||
(a.seqlen_q % 16 == 0 and a.seqlen_q % 64 == 0) && (a.seqlen_k % 128 == 0) &&
|
||||
(a.hdim_q % 128 == 0) && (a.hdim_v % 128 == 0) && (t.is_deterministic == false))
|
||||
{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
|
||||
using dq_dk_dv_trait_ =
|
||||
fmha_bwd_dq_dk_dv_traits_<128,
|
||||
FmhaBwdFp16,
|
||||
false,
|
||||
ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
ck_tile::SimplifiedGenericAttentionMask<false>,
|
||||
ck_tile::BlockDropoutBwd<false, true, false>,
|
||||
ck_tile::BlockAttentionBiasEnum::NO_BIAS,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false>;
|
||||
using convert_dq_trait_ =
|
||||
fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
|
||||
return r;
|
||||
}
|
||||
else{
|
||||
else
|
||||
{
|
||||
assert("unsupported case\n");
|
||||
return r;
|
||||
}
|
||||
@@ -806,11 +819,13 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
float ave_time = fmha_bwd(fmha_traits, fmha_args, stream_config);
|
||||
// using instance:
|
||||
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false, ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP, ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>, ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>;
|
||||
// using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false, false>;
|
||||
// r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a);
|
||||
// return r;
|
||||
// using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<128, FmhaBwdFp16, false, false, false>;
|
||||
// using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<128, FmhaBwdFp16, false,
|
||||
// ck_tile::BlockFmhaBwdPipelineEnum::KRKTRVR_IGLP,
|
||||
// ck_tile::SimplifiedGenericAttentionMask<false>, ck_tile::BlockDropoutBwd<false, true, false>,
|
||||
// ck_tile::BlockAttentionBiasEnum::NO_BIAS, false, false, false, false, false, false>; using
|
||||
// convert_dq_trait_ = fmha_bwd_convert_dq_traits_<128, FmhaBwdFp16, false, false, false,
|
||||
// false>; r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, convert_dq_trait_>(s, a); return r;
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << ", not supported yet" << std::flush << std::endl;
|
||||
|
||||
@@ -30,8 +30,13 @@ auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType,
|
||||
typename ALayout, typename BLayout, typename CLayout>
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
@@ -57,9 +62,9 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
args.stride_B = stride_B;
|
||||
args.stride_C = stride_C;
|
||||
|
||||
float ave_time = gemm_calc<ADataType, BDataType, AccDataType, CDataType,
|
||||
ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
float ave_time =
|
||||
gemm_calc<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat});
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
@@ -69,14 +74,11 @@ float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
|
||||
std::cout << "Run Gemm kernel with M =" << M << " N =" << N << " K =" << K
|
||||
<< " StrideA =" << stride_A << " StrideB =" << stride_B << " StrideC =" << stride_C
|
||||
<< " A_Layout =" << ALayout::name
|
||||
<< " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name
|
||||
<< " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name
|
||||
<< " : " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< std::endl;
|
||||
<< " A_Layout =" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout =" << CLayout::name << " A Type = " << DataTypeTraits<ADataType>::name
|
||||
<< " B Type = " << DataTypeTraits<BDataType>::name
|
||||
<< " C Type = " << DataTypeTraits<CDataType>::name << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
@@ -92,10 +94,10 @@ int run_gemm_example_with_layouts(int argc,
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
using ADataType = typename GemmBasicTypeConfig<PrecType>::ADataType;
|
||||
using BDataType = typename GemmBasicTypeConfig<PrecType>::BDataType;
|
||||
using CDataType = typename GemmBasicTypeConfig<PrecType>::CDataType;
|
||||
using AccDataType = typename GemmBasicTypeConfig<PrecType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
@@ -133,19 +135,19 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType,
|
||||
ALayout, BLayout, CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
invoke_gemm<ADataType, BDataType, AccDataType, CDataType, ALayout, BLayout, CLayout>(
|
||||
a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
@@ -160,9 +162,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
|
||||
(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
@@ -218,9 +220,9 @@ int run_gemm_example_with_layouts(int argc,
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>
|
||||
(K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
|
||||
@@ -137,6 +137,9 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// if (threadIdx.x == 0){
|
||||
// HotLoopScheduler::print();
|
||||
// }
|
||||
// Block GEMM
|
||||
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
|
||||
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
|
||||
@@ -532,7 +535,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
// Hot loop
|
||||
while(i_total_loops < (num_total_loop - 1))
|
||||
{
|
||||
// STAGE 1, Q@K Gemm0
|
||||
// STAGE 1, Q@K Gemm0
|
||||
d_block_tile = load_tile(d_dram_window);
|
||||
move_tile_window(d_dram_window, {kM0});
|
||||
|
||||
@@ -664,7 +667,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
// decltype(p_gemm)>(pt_reg_tensor, p_gemm);
|
||||
|
||||
pt_reg_tensor.get_thread_buffer() = p_gemm.get_thread_buffer();
|
||||
auto qt_reg_tensor = load_tile(qt_lds_read_window);
|
||||
auto qt_reg_tensor = load_tile(qt_lds_read_window);
|
||||
|
||||
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
|
||||
|
||||
HotLoopScheduler::template GemmStagedScheduler<1>();
|
||||
|
||||
@@ -202,9 +202,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: kMinVecLoad;
|
||||
constexpr index_t kVecLoad =
|
||||
((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad;
|
||||
|
||||
return kVecLoad;
|
||||
}
|
||||
@@ -260,9 +259,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
|
||||
? kMaxVecLoad
|
||||
: kMinVecLoad;
|
||||
constexpr index_t kVecLoad =
|
||||
((total_pixels / kMaxVecLoad) >= kMinVecLoad) ? kMaxVecLoad : kMinVecLoad;
|
||||
|
||||
return kVecLoad;
|
||||
}
|
||||
@@ -607,7 +605,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
|
||||
{
|
||||
return GetAlignmentQ<Problem>();
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
return 16 / sizeof(QDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -649,7 +648,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
|
||||
{
|
||||
return GetAlignmentOGrad<Problem>();
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
return 16 / sizeof(OGradDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -666,48 +666,73 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
return 16 / sizeof(GemmDataType);
|
||||
}
|
||||
|
||||
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
|
||||
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, bool XorLdsLayout = true>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
|
||||
{
|
||||
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
|
||||
constexpr auto MNLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
if constexpr(XorLdsLayout)
|
||||
{
|
||||
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
|
||||
constexpr auto MNLdsLayer =
|
||||
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
|
||||
number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
|
||||
number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
x_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPerBlock / KPack * MNLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
x_lds_block_desc_0,
|
||||
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
|
||||
number<KPerBlock / KPack * MNLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
x_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
|
||||
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
x_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
|
||||
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
|
||||
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
|
||||
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return x_lds_block_desc;
|
||||
return x_lds_block_desc;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<MNPerBlock>{},
|
||||
number<KPerBlock / 64>{},
|
||||
number<64 / KPack>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPerBlock / 64 * (64 / KPack + 1) * KPack>{},
|
||||
number<(64 / KPack + 1) * KPack>{},
|
||||
number<KPack>{},
|
||||
number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
x_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(number<MNPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<KPerBlock / 64>{}, number<64 / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem,
|
||||
@@ -986,9 +1011,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
constexpr index_t KPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1193,9 +1218,9 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
constexpr index_t KPack = GetSmemKPackOGrad<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
|
||||
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, KPack, false>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
@@ -1681,14 +1706,17 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
|
||||
constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
|
||||
// To hide instruction issue latency
|
||||
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
|
||||
constexpr index_t LDS_READ_PER_MFMA =
|
||||
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
|
||||
|
||||
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr (i * MFMA_PER_VMEM_READ + j<LDS_READ_INST){
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(i * MFMA_PER_VMEM_READ + j < LDS_READ_INST)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, LDS_READ_PER_MFMA, 0); // DS read
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1709,11 +1737,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MFMA_INST = Gemm1MFMA;
|
||||
|
||||
// To hide instruction issue latency
|
||||
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
|
||||
constexpr index_t LDS_READ_PER_MFMA =
|
||||
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
|
||||
|
||||
static_for<0, MFMA_INST, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr (i <LDS_READ_INST){
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(i < LDS_READ_INST)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
|
||||
}
|
||||
});
|
||||
@@ -1729,11 +1759,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MFMA_INST = Gemm2MFMA;
|
||||
|
||||
// To hide instruction issue latency
|
||||
constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
|
||||
constexpr index_t LDS_WRITE_PER_MFMA =
|
||||
ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
|
||||
|
||||
static_for<0, MFMA_INST, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr (i < LDS_WRITE_INST){
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(i < LDS_WRITE_INST)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
|
||||
}
|
||||
});
|
||||
@@ -1749,31 +1781,43 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MFMA_INST = Gemm3MFMA;
|
||||
|
||||
// To hide instruction issue latency
|
||||
constexpr index_t LDS_WRITE_PER_MFMA = ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
|
||||
constexpr index_t LDS_WRITE_PER_MFMA =
|
||||
ck_tile::integer_divide_ceil(LDS_WRITE_INST, MFMA_INST);
|
||||
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
|
||||
|
||||
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE));
|
||||
constexpr index_t LDS_READ_PER_MFMA =
|
||||
ck_tile::integer_divide_ceil(LDS_READ_INST, (MFMA_INST - MFMA_INST_LDS_WRITE));
|
||||
|
||||
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr (i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST){
|
||||
if constexpr ( (i +1 ) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST){
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(i * LDS_WRITE_PER_MFMA < LDS_WRITE_INST)
|
||||
{
|
||||
if constexpr((i + 1) * LDS_WRITE_PER_MFMA > LDS_WRITE_INST)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x200, LDS_WRITE_INST - i * LDS_WRITE_PER_MFMA, 0); // DS Write
|
||||
}
|
||||
else{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){
|
||||
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST)
|
||||
{
|
||||
if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
|
||||
}
|
||||
else{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, LDS_READ_PER_MFMA, 0); // DS Read
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -1788,21 +1832,42 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MFMA_INST = Gemm4MFMA;
|
||||
|
||||
// To hide instruction issue latency
|
||||
constexpr index_t LDS_READ_PER_MFMA = ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
|
||||
constexpr index_t LDS_READ_PER_MFMA =
|
||||
ck_tile::integer_divide_ceil(LDS_READ_INST, MFMA_INST);
|
||||
|
||||
static_for<0, MFMA_INST, 1>{}([&](auto i) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr (i * LDS_READ_PER_MFMA < LDS_READ_INST){
|
||||
if constexpr ( (i +1 ) * LDS_READ_PER_MFMA > LDS_READ_INST){
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
if constexpr(i * LDS_READ_PER_MFMA < LDS_READ_INST)
|
||||
{
|
||||
if constexpr((i + 1) * LDS_READ_PER_MFMA > LDS_READ_INST)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, LDS_READ_INST - i * LDS_READ_PER_MFMA, 0); // DS Read
|
||||
}
|
||||
else{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, LDS_READ_PER_MFMA, 0); // DS Read
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static void print()
|
||||
{
|
||||
printf("LDS instruction{");
|
||||
//
|
||||
printf("OGradT_LDS_READ: %d, ", OGradT_LDS_READ);
|
||||
printf("OGrad_LDS_READ: %d, ", OGrad_LDS_READ);
|
||||
printf("QT_LDS_READ: %d, ", QT_LDS_READ);
|
||||
printf("Q_LDS_READ: %d, ", Q_LDS_READ);
|
||||
printf("SGradT_LDS_READ_P1: %d, ", SGradT_LDS_READ_P1);
|
||||
printf("SGradT_LDS_READ_P2: %d, ", SGradT_LDS_READ_P2);
|
||||
printf("LSE_LDS_READ: %d, ", LSE_LDS_READ);
|
||||
printf("D_LDS_READ: %d, ", D_LDS_READ);
|
||||
printf("}");
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
|
||||
@@ -1818,6 +1883,10 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
static constexpr index_t WarpGemmN =
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
|
||||
static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
|
||||
static constexpr index_t Gemm0MWarp =
|
||||
Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
static constexpr index_t Gemm2MWarp =
|
||||
Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
|
||||
static constexpr index_t Gemm4MWarp =
|
||||
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
|
||||
static constexpr index_t Gemm4NWarp =
|
||||
@@ -1847,20 +1916,29 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
static constexpr index_t D_VMEM_READ = 1;
|
||||
|
||||
// LDS Read
|
||||
// 16 * 128 / 64 / 4 = 8
|
||||
static constexpr index_t OGradT_LDS_READ =
|
||||
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
|
||||
// 16 * 128 / 64 / 4 = 8
|
||||
static constexpr index_t QT_LDS_READ =
|
||||
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
// 16 * 32 / 64 / 8 = 1
|
||||
static constexpr index_t SGradT_LDS_READ_P1 =
|
||||
// kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / 2;
|
||||
static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
static constexpr index_t Q_LDS_READ =
|
||||
kM0 * kK0 / (get_warp_size() * Gemm0MWarp) / GetAlignmentQ<Problem>();
|
||||
// 1
|
||||
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
// 16 * 96 / 64 / 8 = 3
|
||||
static constexpr index_t SGradT_LDS_READ_P2 =
|
||||
// kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / 2;
|
||||
// 16 * 128 / 64 / 8 = 4
|
||||
static constexpr index_t OGrad_LDS_READ =
|
||||
kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
kM0 * kK2 / (get_warp_size() * Gemm2MWarp) / GetAlignmentOGrad<Problem>();
|
||||
// 1
|
||||
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
|
||||
|
||||
// LDS Write
|
||||
|
||||
Reference in New Issue
Block a user