diff --git a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py index 096394c0c9..0de4912c13 100644 --- a/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py +++ b/example/ck_tile/01_fmha/codegen/ops/fmha_bwd.py @@ -455,6 +455,8 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict "kr_ktr_vr_iglp", "kr_ktr_vr"], '64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"], + '96' : [FmhaBwdDQDKDVTileSize( 16, 128, 96, 16, 96, 16, 32, 96, 96, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), + "kr_ktr_vr_iglp", "kr_ktr_vr"], '128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), "kr_ktr_vr_iglp", "kr_ktr_vr"], '256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1), @@ -801,4 +803,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im _, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl) for kernel in kernels: f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n") - f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") \ No newline at end of file + f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n") diff --git a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp index c5858a20f7..534f101854 100644 --- a/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp +++ b/include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp @@ -51,14 +51,14 @@ struct FmhaBwdDQDKDVKernel using VGradDataType = ck_tile::remove_cvref_t; using BiasGradDataType = ck_tile::remove_cvref_t; - static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; - static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; - static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; - static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; - using FmhaMask = ck_tile::remove_cvref_t; + static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode; + static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV; + static constexpr auto BiasEnum = FmhaPipeline::BiasEnum; + static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad; + using FmhaMask = ck_tile::remove_cvref_t; using FmhaDropout = ck_tile::remove_cvref_t; static constexpr bool kHasMask = FmhaMask::IsMasking; static constexpr bool kHasDropout = FmhaDropout::IsDropout; @@ -776,7 +776,7 @@ struct FmhaBwdDQDKDVKernel number<1>{}); const auto q_dram = pad_tensor_view( q_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); const auto k_dram_naive = make_naive_tensor_view( @@ -787,7 +787,7 @@ struct FmhaBwdDQDKDVKernel number<1>{}); const auto k_dram = pad_tensor_view( k_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); const auto v_dram = [&]() { @@ -799,7 +799,7 @@ struct FmhaBwdDQDKDVKernel number<1>{}); return pad_tensor_view( v_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); }(); @@ -825,27 +825,27 @@ struct FmhaBwdDQDKDVKernel number<1>{}); const auto do_dram = pad_tensor_view( do_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); auto q_dram_window = make_tile_window( q_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}); auto k_dram_window = make_tile_window( k_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_n0, 0}); auto v_dram_window = make_tile_window( v_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_n0, 0}); auto do_dram_window = make_tile_window( do_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}); auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() { @@ -866,16 +866,16 @@ struct FmhaBwdDQDKDVKernel number{}, number<1>{}); - return pad_tensor_view( - dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(dq_acc_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); }(); - return make_tile_window( - dq_acc_dram, - make_tuple(number{}, number{}), - {0, 0}); + return make_tile_window(dq_acc_dram, + make_tuple(number{}, + number{}), + {0, 0}); } else { @@ -894,16 +894,16 @@ struct FmhaBwdDQDKDVKernel number{}, number<1>{}); - return pad_tensor_view( - dq_acc_dram_naive, - make_tuple(number{}, number{}), - sequence{}); + return pad_tensor_view(dq_acc_dram_naive, + make_tuple(number{}, + number{}), + sequence{}); }(); - return make_tile_window( - dq_acc_dram, - make_tuple(number{}, number{}), - {0, 0}); + return make_tile_window(dq_acc_dram, + make_tuple(number{}, + number{}), + {0, 0}); } }(); @@ -1105,7 +1105,7 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dk_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); }(); @@ -1119,18 +1119,18 @@ struct FmhaBwdDQDKDVKernel return pad_tensor_view( dv_dram_naive, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), sequence{}); }(); auto dk_dram_window = make_tile_window( dk_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_n0, 0}); auto dv_dram_window = make_tile_window( dv_dram, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {i_n0, 0}); KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile); diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp index 131729992b..e2719a2137 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp @@ -38,15 +38,17 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK2 = BlockFmhaShape::kK2; - static constexpr index_t kK3 = BlockFmhaShape::kK3; - static constexpr index_t kK4 = BlockFmhaShape::kK4; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kQKHeaddimForGemmN = BlockFmhaShape::kQKHeaddimForGemmN; + static constexpr index_t kVHeaddimForGemmN = BlockFmhaShape::kVHeaddimForGemmN; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -177,8 +179,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); - auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + auto k_lds_write_window = make_tile_window( + k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), @@ -204,7 +206,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), @@ -227,14 +229,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); auto kt_lds_read_window = make_tile_window(kt_lds_read, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeKTRegBlockDescriptor()); @@ -275,8 +277,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR auto q_lds = make_tensor_view( q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + auto q_lds_window = make_tile_window( + q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -297,14 +299,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); auto qt_lds_read_window = make_tile_window(qt_lds_read, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeQTRegSliceBlockDescriptor()); @@ -321,8 +323,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR auto do_lds = make_tensor_view( do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); - auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + auto do_lds_window = make_tile_window( + do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -341,14 +343,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); auto dot_lds_read_window = make_tile_window(dot_read_lds, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeOGradTRegSliceBlockDescriptor()); @@ -485,7 +487,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR index_t seqlen_q_step = seqlen_q_start; static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + //static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; @@ -728,9 +730,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR ds_reg_tensor_next = load_tile(ds_lds_read_window); move_tile_window(ds_lds_read_window, {0, kK4}); } - auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor, - sequence<0, i_k4 * kK4>{}, - sequence{}); + auto kt_reg_tensor_slice = + get_slice_tile(kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); if constexpr(i_k4 < k4_loops - 1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp index 3156e4a356..68fc5a12bc 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp @@ -38,15 +38,17 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP static constexpr index_t kBlockPerCu = Problem::kBlockPerCu; static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = BlockFmhaShape::kM0; - static constexpr index_t kN0 = BlockFmhaShape::kN0; - static constexpr index_t kK0 = BlockFmhaShape::kK0; - static constexpr index_t kK1 = BlockFmhaShape::kK1; - static constexpr index_t kK2 = BlockFmhaShape::kK2; - static constexpr index_t kK3 = BlockFmhaShape::kK3; - static constexpr index_t kK4 = BlockFmhaShape::kK4; - static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kM0 = BlockFmhaShape::kM0; + static constexpr index_t kN0 = BlockFmhaShape::kN0; + static constexpr index_t kK0 = BlockFmhaShape::kK0; + static constexpr index_t kK1 = BlockFmhaShape::kK1; + static constexpr index_t kK2 = BlockFmhaShape::kK2; + static constexpr index_t kK3 = BlockFmhaShape::kK3; + static constexpr index_t kK4 = BlockFmhaShape::kK4; + static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; + static constexpr index_t kQKHeaddimForGemmN = BlockFmhaShape::kQKHeaddimForGemmN; + static constexpr index_t kVHeaddimForGemmN = BlockFmhaShape::kVHeaddimForGemmN; static constexpr bool kIsGroupMode = Problem::kIsGroupMode; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ; @@ -177,8 +179,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto k_lds = make_tensor_view( k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor()); - auto k_lds_write_window = - make_tile_window(k_lds, make_tuple(number{}, number{}), {0, 0}); + auto k_lds_write_window = make_tile_window( + k_lds, make_tuple(number{}, number{}), {0, 0}); auto k_lds_read_window = make_tile_window(k_lds_write_window.get_bottom_tensor_view(), @@ -204,7 +206,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor()); auto v_lds_write_window = - make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); + make_tile_window(v_lds, make_tuple(number{}, number{}), {0, 0}); auto v_lds_read_window = make_tile_window(v_lds_write_window.get_bottom_tensor_view(), @@ -227,14 +229,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor()); auto shuffled_k_lds_write_window = make_tile_window( - shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_k_lds_write, make_tuple(number{}, number{}), {0, 0}); auto kt_lds_read = make_tensor_view( kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor()); auto kt_lds_read_window = make_tile_window(kt_lds_read, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeKTRegBlockDescriptor()); @@ -274,8 +276,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto q_lds = make_tensor_view( q_lds_ptr, Policy::template MakeQLdsBlockDescriptor()); - auto q_lds_window = - make_tile_window(q_lds, make_tuple(number{}, number{}), {0, 0}); + auto q_lds_window = make_tile_window( + q_lds, make_tuple(number{}, number{}), {0, 0}); auto q_lds_read_window = make_tile_window(q_lds_window.get_bottom_tensor_view(), @@ -296,14 +298,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor()); auto shuffled_q_lds_write_window = make_tile_window( - shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_q_lds_write, make_tuple(number{}, number{}), {0, 0}); auto qt_lds_read = make_tensor_view( qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor()); auto qt_lds_read_window = make_tile_window(qt_lds_read, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeQTRegSliceBlockDescriptor()); @@ -320,8 +322,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP auto do_lds = make_tensor_view( do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor()); - auto do_lds_window = - make_tile_window(do_lds, make_tuple(number{}, number{}), {0, 0}); + auto do_lds_window = make_tile_window( + do_lds, make_tuple(number{}, number{}), {0, 0}); auto do_lds_read_window = make_tile_window(do_lds_window.get_bottom_tensor_view(), @@ -340,14 +342,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor()); auto shuffled_do_lds_write_window = make_tile_window( - shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); + shuffled_do_lds_write, make_tuple(number{}, number{}), {0, 0}); auto dot_read_lds = make_tensor_view( dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor()); auto dot_lds_read_window = make_tile_window(dot_read_lds, - make_tuple(number{}, number{}), + make_tuple(number{}, number{}), {0, 0}, Policy::template MakeOGradTRegSliceBlockDescriptor()); @@ -484,7 +486,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP index_t seqlen_q_step = seqlen_q_start; static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0"); static_assert(kM0 == kK1, "kM0 should equal to kK1"); - static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); + //static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2"); static_assert(kM0 == kK3, "kM0 should equal to kK3"); constexpr index_t k4_loops = kN0 / kK4; @@ -763,9 +765,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ds_reg_tensor_next = load_tile(ds_lds_read_window); move_tile_window(ds_lds_read_window, {0, kK4}); } - auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor, - sequence<0, i_k4 * kK4>{}, - sequence{}); + auto kt_reg_tensor_slice = + get_slice_tile(kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); if constexpr(i_k4 < k4_loops - 1) @@ -999,8 +1002,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP ds_reg_tensor_next = load_tile(ds_lds_read_window); move_tile_window(ds_lds_read_window, {0, kK4}); } - auto kt_reg_tensor_slice = get_slice_tile( - kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence{}); + auto kt_reg_tensor_slice = + get_slice_tile(kt_reg_tensor, + sequence<0, i_k4 * kK4>{}, + sequence{}); gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice); if constexpr(i_k4 < k4_loops - 1) diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp index 8647a7d25a..ca39f457e8 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp @@ -63,7 +63,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::OGradDataType, typename Problem::AccDataType, TileGemmShape, typename Problem::BlockFmhaShape::Gemm1BlockWarps, typename Problem::BlockFmhaShape::Gemm1WarpTile>>; @@ -128,7 +128,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::QDataType, typename Problem::AccDataType, TileGemmShape, typename Problem::BlockFmhaShape::Gemm3BlockWarps, typename Problem::BlockFmhaShape::Gemm3WarpTile>>; @@ -160,7 +160,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy typename Problem::KDataType, typename Problem::AccDataType, TileGemmShape, typename Problem::BlockFmhaShape::Gemm4BlockWarps, typename Problem::BlockFmhaShape::Gemm4WarpTile>>; @@ -191,7 +191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using QDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType); constexpr index_t kMinVecLoad = 4 / sizeof(QDataType); @@ -210,7 +210,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using KDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); constexpr index_t kMinVecLoad = 4 / sizeof(KDataType); @@ -229,7 +229,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using VDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; @@ -249,7 +249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy using OGradDataType = remove_cvref_t; constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType); @@ -310,7 +310,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -322,7 +322,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; return total_pixels / GetAlignmentK(); @@ -333,7 +333,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; @@ -371,7 +371,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -417,7 +417,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -440,7 +440,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -811,7 +811,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() { constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPack = GetSmemKPackK(); return MakeXLdsBlockDescriptor(); @@ -935,7 +935,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; @@ -961,7 +961,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t K1 = GetAlignmentK(); constexpr index_t K0 = kKPerBlock / K1; @@ -982,7 +982,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor() { // Hold all data - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kKPack = GetSmemKPackK(); @@ -994,7 +994,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy template CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor() { - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor(); @@ -1017,7 +1017,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{}); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); @@ -1043,7 +1043,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() { constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPack = GetSmemKPackQ(); @@ -1087,7 +1087,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t K1 = GetAlignmentQ(); constexpr index_t K0 = kKPerBlock / K1; @@ -1108,7 +1108,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor() { // Hold full block data - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPack = GetSmemKPackQ(); @@ -1121,7 +1121,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor() { // Hold full block data - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor(); @@ -1144,7 +1144,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{}); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); @@ -1250,7 +1250,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { // Hold full block data constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t kKPack = GetSmemKPackOGrad(); @@ -1294,7 +1294,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy { constexpr index_t kBlockSize = Problem::kBlockSize; - constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; + constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t K1 = GetAlignmentOGrad(); constexpr index_t K0 = kKPerBlock / K1; @@ -1315,7 +1315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor() { // Hold all data - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPack = GetSmemKPackOGrad(); @@ -1328,7 +1328,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor() { // Hold all data - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0; auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor(); @@ -1350,7 +1350,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{}); - constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; + constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN; // constexpr index_t kNPerBlock = 32; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1; @@ -1849,12 +1849,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy } private: - static constexpr index_t kBlockSize = Problem::kBlockSize; - static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; - static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; - static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; - static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; - static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; + static constexpr index_t kBlockSize = Problem::kBlockSize; + static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0; + static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; + static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kQKHeaddimForGemmN = Problem::BlockFmhaShape::kQKHeaddim; + static constexpr index_t kVHeaddimForGemmN = Problem::BlockFmhaShape::kVHeaddim; + static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; + static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0; + static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2; static constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}); @@ -1868,34 +1872,34 @@ struct BlockFmhaBwdPipelineDefaultPolicy // Compute static constexpr index_t Gemm0MFMA = - kM0 * kN0 * kQKHeaddim / + kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm1MFMA = - kM0 * kN0 * kVHeaddim / + kN0 * kVHeaddimForGemmN * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm2MFMA = - kN0 * kVHeaddim * kM0 / + kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm3MFMA = - kN0 * kQKHeaddim * kM0 / + kN0 * kQKHeaddimForGemmN * kM0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); static constexpr index_t Gemm4MFMA = - kM0 * kQKHeaddim * kN0 / + kM0 * kQKHeaddimForGemmN * kN0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); // VMEM static constexpr index_t Q_VMEM_READ = - kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ(); + kM0 * kQKHeaddimForGemmN / kBlockSize / GetAlignmentQ(); static constexpr index_t OGrad_VMEM_READ = - kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + kM0 * kVHeaddimForGemmN / kBlockSize / GetAlignmentOGrad(); static constexpr index_t LSE_VMEM_READ = 1; static constexpr index_t D_VMEM_READ = 1; // LDS Read static constexpr index_t OGradT_LDS_READ = - kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad(); + kM0 * kVHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentOGrad(); static constexpr index_t QT_LDS_READ = - kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ(); + kM0 * kQKHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentQ(); static constexpr index_t SGradT_LDS_READ_P1 = kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad(); static constexpr index_t Q_LDS_READ = @@ -1909,13 +1913,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy // LDS Write static constexpr index_t Q_LDS_WRITE = - kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ(); + kM0 * kQKHeaddimForGemmN / Problem::kBlockSize / GetAlignmentQ(); static constexpr index_t QT_LDS_WRITE = - kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ(); + kM0 * kQKHeaddimForGemmN / kBlockSize / GetTransposedAlignmentQ(); static constexpr index_t OGrad_LDS_WRITE = - kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad(); + kM0 * kVHeaddimForGemmN / kBlockSize / GetAlignmentOGrad(); static constexpr index_t OGradT_LDS_WRITE = - kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad(); + kM0 * kVHeaddimForGemmN / kBlockSize / GetTransposedAlignmentOGrad(); static constexpr index_t LSE_LDS_WRITE = 1; static constexpr index_t D_LDS_WRITE = 1; static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize; diff --git a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp index c4c4a745a7..e49b5a8a19 100644 --- a/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp @@ -55,13 +55,14 @@ struct BlockFmhaBwdPipelineProblem static constexpr bool kIsDeterministic = kIsDeterministic_; // attributes from traits - static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; - static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; - static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; - static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; - static constexpr auto BiasEnum = Traits::BiasEnum; - static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; - static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; + static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; + static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK; + static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ; + static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; + static constexpr bool kPadHeadDimDoDv = Traits::kPadHeadDimDoDv; + static constexpr auto BiasEnum = Traits::BiasEnum; + static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; + static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; }; template {}); // V headdim, used for pipeline // that need load V at once + + static constexpr index_t kQKHeaddimForGemmN = ceil_to_qualified_tile_length(kQKHeaddim); + + static constexpr index_t kVHeaddimForGemmN = ceil_to_qualified_tile_length(kVHeaddim); + + static_assert(kQKHeaddimForGemmN > 0, "Check failed!"); + static_assert(kVHeaddimForGemmN > 0, "Check failed!"); }; } // namespace ck_tile diff --git a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp index e3187042d2..695c0ddd43 100644 --- a/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp +++ b/include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp @@ -33,6 +33,30 @@ struct TileFmhaTraits static constexpr index_t kBlockPerCu = kBlockPerCu_; }; +template +struct TileFmhaBwdTraits +{ + static constexpr bool kPadSeqLenQ = kPadSeqLenQ_; + static constexpr bool kPadSeqLenK = kPadSeqLenK_; + static constexpr bool kPadHeadDimQ = kPadHeadDimQ_; + static constexpr bool kPadHeadDimV = kPadHeadDimV_; + static constexpr auto BiasEnum = BiasEnum_; + static constexpr bool kHasBiasGrad = kHasBiasGrad_; + static constexpr bool kStoreLSE = kStoreLSE_; + static constexpr bool kHasDropout = kHasDropout_; + static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_; + static constexpr index_t kBlockPerCu = kBlockPerCu_; +}; + template