Add kQKHeaddimForGemmN and kVHeaddimForGemmN in order to support headdim 96

This commit is contained in:
Qianfeng Zhang
2024-10-09 14:27:28 +00:00
parent 0c094daa7e
commit fb496fa2d6
8 changed files with 206 additions and 146 deletions

View File

@@ -455,6 +455,8 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'96' : [FmhaBwdDQDKDVTileSize( 16, 128, 96, 16, 96, 16, 32, 96, 96, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
"kr_ktr_vr_iglp", "kr_ktr_vr"],
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
@@ -801,4 +803,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
for kernel in kernels:
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")

View File

@@ -51,14 +51,14 @@ struct FmhaBwdDQDKDVKernel
using VGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VGradDataType>;
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
static constexpr bool kHasMask = FmhaMask::IsMasking;
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
@@ -776,7 +776,7 @@ struct FmhaBwdDQDKDVKernel
number<1>{});
const auto q_dram = pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -787,7 +787,7 @@ struct FmhaBwdDQDKDVKernel
number<1>{});
const auto k_dram = pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
const auto v_dram = [&]() {
@@ -799,7 +799,7 @@ struct FmhaBwdDQDKDVKernel
number<1>{});
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}();
@@ -825,27 +825,27 @@ struct FmhaBwdDQDKDVKernel
number<1>{});
const auto do_dram = pad_tensor_view(
do_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
sequence<kPadSeqLenQ, kPadHeadDimV>{});
auto q_dram_window = make_tile_window(
q_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
{0, 0});
auto k_dram_window = make_tile_window(
k_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
{i_n0, 0});
auto v_dram_window = make_tile_window(
v_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
{i_n0, 0});
auto do_dram_window = make_tile_window(
do_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
{0, 0});
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
@@ -866,16 +866,16 @@ struct FmhaBwdDQDKDVKernel
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
return make_tile_window(dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
{0, 0});
}
else
{
@@ -894,16 +894,16 @@ struct FmhaBwdDQDKDVKernel
number<FmhaPipeline::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(
dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
{0, 0});
return make_tile_window(dq_acc_dram,
make_tuple(number<FmhaPipeline::kM0>{},
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
{0, 0});
}
}();
@@ -1105,7 +1105,7 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dk_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
@@ -1119,18 +1119,18 @@ struct FmhaBwdDQDKDVKernel
return pad_tensor_view(
dv_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
sequence<kPadSeqLenK, kPadHeadDimV>{});
}();
auto dk_dram_window = make_tile_window(
dk_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
{i_n0, 0});
auto dv_dram_window = make_tile_window(
dv_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
{i_n0, 0});
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);

View File

@@ -38,15 +38,17 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kQKHeaddimForGemmN = BlockFmhaShape::kQKHeaddimForGemmN;
static constexpr index_t kVHeaddimForGemmN = BlockFmhaShape::kVHeaddimForGemmN;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -177,8 +179,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_write_window = make_tile_window(
k_lds, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
@@ -204,7 +206,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
@@ -227,14 +229,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
auto kt_lds_read_window =
make_tile_window(kt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
make_tuple(number<kQKHeaddimForGemmN>{}, number<kN0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
@@ -275,8 +277,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto q_lds_window = make_tile_window(
q_lds, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
@@ -297,14 +299,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(qt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
make_tuple(number<kQKHeaddimForGemmN>{}, number<kM0>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
@@ -321,8 +323,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
auto do_lds_window = make_tile_window(
do_lds, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
@@ -341,14 +343,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
auto dot_lds_read_window =
make_tile_window(dot_read_lds,
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
make_tuple(number<kVHeaddimForGemmN>{}, number<kM0>{}),
{0, 0},
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
@@ -485,7 +487,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
//static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
@@ -728,9 +730,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
auto kt_reg_tensor_slice =
get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddimForGemmN, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)

View File

@@ -38,15 +38,17 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr index_t kQKHeaddimForGemmN = BlockFmhaShape::kQKHeaddimForGemmN;
static constexpr index_t kVHeaddimForGemmN = BlockFmhaShape::kVHeaddimForGemmN;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
@@ -177,8 +179,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_write_window = make_tile_window(
k_lds, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
@@ -204,7 +206,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
@@ -227,14 +229,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
auto kt_lds_read_window =
make_tile_window(kt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
make_tuple(number<kQKHeaddimForGemmN>{}, number<kN0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
@@ -274,8 +276,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto q_lds_window = make_tile_window(
q_lds, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
@@ -296,14 +298,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(qt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
make_tuple(number<kQKHeaddimForGemmN>{}, number<kM0>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
@@ -320,8 +322,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
auto do_lds_window = make_tile_window(
do_lds, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
@@ -340,14 +342,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
auto dot_lds_read_window =
make_tile_window(dot_read_lds,
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
make_tuple(number<kVHeaddimForGemmN>{}, number<kM0>{}),
{0, 0},
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
@@ -484,7 +486,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
//static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
@@ -763,9 +765,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
auto kt_reg_tensor_slice =
get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddimForGemmN, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
@@ -999,8 +1002,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(
kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
auto kt_reg_tensor_slice =
get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddimForGemmN, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)

View File

@@ -63,7 +63,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::OGradDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kVHeaddimForGemmN,
Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
@@ -128,7 +128,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::QDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kQKHeaddimForGemmN,
Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
@@ -160,7 +160,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
typename Problem::KDataType,
typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kQKHeaddimForGemmN,
Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
@@ -191,7 +191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
@@ -210,7 +210,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
@@ -229,7 +229,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
@@ -249,7 +249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
@@ -310,7 +310,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
@@ -322,7 +322,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentK<Problem>();
@@ -333,7 +333,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
@@ -371,7 +371,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
@@ -417,7 +417,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
@@ -440,7 +440,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
@@ -811,7 +811,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
@@ -935,7 +935,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
@@ -961,7 +961,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
@@ -982,7 +982,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
@@ -994,7 +994,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor<Problem>();
@@ -1017,7 +1017,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
@@ -1043,7 +1043,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
@@ -1087,7 +1087,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
@@ -1108,7 +1108,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
{
// Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
@@ -1121,7 +1121,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor()
{
// Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor<Problem>();
@@ -1144,7 +1144,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
@@ -1250,7 +1250,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
// Hold full block data
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
@@ -1294,7 +1294,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
@@ -1315,7 +1315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
@@ -1328,7 +1328,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor<Problem>();
@@ -1350,7 +1350,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
// constexpr index_t kNPerBlock = 32;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
@@ -1849,12 +1849,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
}
private:
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kQKHeaddimForGemmN = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddimForGemmN = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
static constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
@@ -1868,34 +1872,34 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute
static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim /
kM0 * kN0 * kK0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
kN0 * kVHeaddimForGemmN * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 /
kM0 * kN0 * kK2 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 /
kN0 * kQKHeaddimForGemmN * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm4MFMA =
kM0 * kQKHeaddim * kN0 /
kM0 * kQKHeaddimForGemmN * kN0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
// VMEM
static constexpr index_t Q_VMEM_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
kM0 * kQKHeaddimForGemmN / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t OGrad_VMEM_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
kM0 * kVHeaddimForGemmN / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
// LDS Read
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
kM0 * kVHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t QT_LDS_READ =
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
kM0 * kQKHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ =
@@ -1909,13 +1913,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// LDS Write
static constexpr index_t Q_LDS_WRITE =
kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
kM0 * kQKHeaddimForGemmN / Problem::kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t QT_LDS_WRITE =
kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
kM0 * kQKHeaddimForGemmN / kBlockSize / GetTransposedAlignmentQ<Problem>();
static constexpr index_t OGrad_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
kM0 * kVHeaddimForGemmN / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
kM0 * kVHeaddimForGemmN / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;

View File

@@ -55,13 +55,14 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr bool kPadHeadDimDoDv = Traits::kPadHeadDimDoDv;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename ODataType_,

View File

@@ -7,6 +7,20 @@
namespace ck_tile {
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
{
if(len == 96)
return 128;
if(len == 160)
return 256;
// only length of 96, 160 and power-of-two is supported
if(!(len & (len - 1)))
return len;
return 0;
};
template <typename BlockTile_, // sequence<...
typename Gemm0BlockWarps_,
typename Gemm0WarpTile_,
@@ -90,6 +104,13 @@ struct TileFmhaBwdShape
// K/K^T at once
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
// that need load V at once
static constexpr index_t kQKHeaddimForGemmN = ceil_to_qualified_tile_length(kQKHeaddim);
static constexpr index_t kVHeaddimForGemmN = ceil_to_qualified_tile_length(kVHeaddim);
static_assert(kQKHeaddimForGemmN > 0, "Check failed!");
static_assert(kVHeaddimForGemmN > 0, "Check failed!");
};
} // namespace ck_tile

View File

@@ -33,6 +33,30 @@ struct TileFmhaTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
bool kPadHeadDimV_ /* paddding for hdim_v */,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaBwdTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadSeqLenK_ /* padding for seqlen_k */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,