mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 05:55:39 +00:00
Add kQKHeaddimForGemmN and kVHeaddimForGemmN in order to support headdim 96
This commit is contained in:
@@ -455,6 +455,8 @@ def get_fmha_bwd_dq_dk_dv_tile_ppl_dict_from_dtype(dtype : str) -> Optional[dict
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'64' : [FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'96' : [FmhaBwdDQDKDVTileSize( 16, 128, 96, 16, 96, 16, 32, 96, 96, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'128' : [FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
"kr_ktr_vr_iglp", "kr_ktr_vr"],
|
||||
'256' : [FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
@@ -801,4 +803,4 @@ def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, mask_im
|
||||
_, kernels = get_bwd_dq_dk_dv_blobs(kernel_filter, receipt, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
|
||||
@@ -51,14 +51,14 @@ struct FmhaBwdDQDKDVKernel
|
||||
using VGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VGradDataType>;
|
||||
using BiasGradDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasGradDataType>;
|
||||
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
static constexpr bool kIsGroupMode = FmhaPipeline::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
|
||||
using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
|
||||
using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
|
||||
static constexpr bool kHasMask = FmhaMask::IsMasking;
|
||||
static constexpr bool kHasDropout = FmhaDropout::IsDropout;
|
||||
@@ -776,7 +776,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<1>{});
|
||||
const auto q_dram = pad_tensor_view(
|
||||
q_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
@@ -787,7 +787,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<1>{});
|
||||
const auto k_dram = pad_tensor_view(
|
||||
k_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
|
||||
const auto v_dram = [&]() {
|
||||
@@ -799,7 +799,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<1>{});
|
||||
return pad_tensor_view(
|
||||
v_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
@@ -825,27 +825,27 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<1>{});
|
||||
const auto do_dram = pad_tensor_view(
|
||||
do_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimV>{});
|
||||
|
||||
auto q_dram_window = make_tile_window(
|
||||
q_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
{0, 0});
|
||||
|
||||
auto k_dram_window = make_tile_window(
|
||||
k_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
auto v_dram_window = make_tile_window(
|
||||
v_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
auto do_dram_window = make_tile_window(
|
||||
do_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
|
||||
{0, 0});
|
||||
|
||||
auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
|
||||
@@ -866,16 +866,16 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
return pad_tensor_view(dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
return make_tile_window(dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
{0, 0});
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -894,16 +894,16 @@ struct FmhaBwdDQDKDVKernel
|
||||
number<FmhaPipeline::kAlignmentQGrad>{},
|
||||
number<1>{});
|
||||
|
||||
return pad_tensor_view(
|
||||
dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
return pad_tensor_view(dq_acc_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
return make_tile_window(
|
||||
dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
{0, 0});
|
||||
return make_tile_window(dq_acc_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
{0, 0});
|
||||
}
|
||||
}();
|
||||
|
||||
@@ -1105,7 +1105,7 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
dk_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimQ>{});
|
||||
}();
|
||||
|
||||
@@ -1119,18 +1119,18 @@ struct FmhaBwdDQDKDVKernel
|
||||
|
||||
return pad_tensor_view(
|
||||
dv_dram_naive,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
|
||||
sequence<kPadSeqLenK, kPadHeadDimV>{});
|
||||
}();
|
||||
|
||||
auto dk_dram_window = make_tile_window(
|
||||
dk_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddimForGemmN>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
auto dv_dram_window = make_tile_window(
|
||||
dv_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
|
||||
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddimForGemmN>{}),
|
||||
{i_n0, 0});
|
||||
|
||||
KGradEpiloguePipeline{}(dk_dram_window, dk_acc_tile);
|
||||
|
||||
@@ -38,15 +38,17 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kQKHeaddimForGemmN = BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
static constexpr index_t kVHeaddimForGemmN = BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
@@ -177,8 +179,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
k_lds, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
@@ -204,7 +206,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
@@ -227,14 +229,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(kt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
|
||||
make_tuple(number<kQKHeaddimForGemmN>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
@@ -275,8 +277,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
auto q_lds_window = make_tile_window(
|
||||
q_lds, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -297,14 +299,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto qt_lds_read_window =
|
||||
make_tile_window(qt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
|
||||
make_tuple(number<kQKHeaddimForGemmN>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
@@ -321,8 +323,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
auto do_lds_window = make_tile_window(
|
||||
do_lds, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -341,14 +343,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto dot_lds_read_window =
|
||||
make_tile_window(dot_read_lds,
|
||||
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
|
||||
make_tuple(number<kVHeaddimForGemmN>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
@@ -485,7 +487,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
//static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
@@ -728,9 +730,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
|
||||
ds_reg_tensor_next = load_tile(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
auto kt_reg_tensor_slice =
|
||||
get_slice_tile(kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddimForGemmN, (i_k4 + 1) * kK4>{});
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
|
||||
@@ -38,15 +38,17 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kM0 = BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = BlockFmhaShape::kN0;
|
||||
static constexpr index_t kK0 = BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kK2 = BlockFmhaShape::kK2;
|
||||
static constexpr index_t kK3 = BlockFmhaShape::kK3;
|
||||
static constexpr index_t kK4 = BlockFmhaShape::kK4;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kQKHeaddimForGemmN = BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
static constexpr index_t kVHeaddimForGemmN = BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
|
||||
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
|
||||
@@ -177,8 +179,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto k_lds = make_tensor_view<address_space_enum::lds>(
|
||||
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto k_lds_write_window =
|
||||
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
auto k_lds_write_window = make_tile_window(
|
||||
k_lds, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto k_lds_read_window =
|
||||
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
|
||||
@@ -204,7 +206,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto v_lds_write_window =
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
|
||||
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto v_lds_read_window =
|
||||
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
|
||||
@@ -227,14 +229,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_k_lds_write_window = make_tile_window(
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto kt_lds_read_window =
|
||||
make_tile_window(kt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
|
||||
make_tuple(number<kQKHeaddimForGemmN>{}, number<kN0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeKTRegBlockDescriptor<Problem>());
|
||||
|
||||
@@ -274,8 +276,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto q_lds = make_tensor_view<address_space_enum::lds>(
|
||||
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto q_lds_window =
|
||||
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
auto q_lds_window = make_tile_window(
|
||||
q_lds, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto q_lds_read_window =
|
||||
make_tile_window(q_lds_window.get_bottom_tensor_view(),
|
||||
@@ -296,14 +298,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_q_lds_write_window = make_tile_window(
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
|
||||
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kQKHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
|
||||
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto qt_lds_read_window =
|
||||
make_tile_window(qt_lds_read,
|
||||
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
|
||||
make_tuple(number<kQKHeaddimForGemmN>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
@@ -320,8 +322,8 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
auto do_lds = make_tensor_view<address_space_enum::lds>(
|
||||
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
|
||||
|
||||
auto do_lds_window =
|
||||
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
auto do_lds_window = make_tile_window(
|
||||
do_lds, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto do_lds_read_window =
|
||||
make_tile_window(do_lds_window.get_bottom_tensor_view(),
|
||||
@@ -340,14 +342,14 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
|
||||
|
||||
auto shuffled_do_lds_write_window = make_tile_window(
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
|
||||
shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kVHeaddimForGemmN>{}), {0, 0});
|
||||
|
||||
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
|
||||
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
|
||||
|
||||
auto dot_lds_read_window =
|
||||
make_tile_window(dot_read_lds,
|
||||
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
|
||||
make_tuple(number<kVHeaddimForGemmN>{}, number<kM0>{}),
|
||||
{0, 0},
|
||||
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
|
||||
|
||||
@@ -484,7 +486,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
index_t seqlen_q_step = seqlen_q_start;
|
||||
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
|
||||
static_assert(kM0 == kK1, "kM0 should equal to kK1");
|
||||
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
//static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
|
||||
static_assert(kM0 == kK3, "kM0 should equal to kK3");
|
||||
constexpr index_t k4_loops = kN0 / kK4;
|
||||
|
||||
@@ -763,9 +765,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
ds_reg_tensor_next = load_tile(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
auto kt_reg_tensor_slice =
|
||||
get_slice_tile(kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddimForGemmN, (i_k4 + 1) * kK4>{});
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
@@ -999,8 +1002,10 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
|
||||
ds_reg_tensor_next = load_tile(ds_lds_read_window);
|
||||
move_tile_window(ds_lds_read_window, {0, kK4});
|
||||
}
|
||||
auto kt_reg_tensor_slice = get_slice_tile(
|
||||
kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
|
||||
auto kt_reg_tensor_slice =
|
||||
get_slice_tile(kt_reg_tensor,
|
||||
sequence<0, i_k4 * kK4>{},
|
||||
sequence<kQKHeaddimForGemmN, (i_k4 + 1) * kK4>{});
|
||||
|
||||
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
|
||||
if constexpr(i_k4 < k4_loops - 1)
|
||||
|
||||
@@ -63,7 +63,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::OGradDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kVHeaddim,
|
||||
Problem::BlockFmhaShape::kVHeaddimForGemmN,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
@@ -128,7 +128,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::QDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kQKHeaddimForGemmN,
|
||||
Problem::BlockFmhaShape::kK3>,
|
||||
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
|
||||
@@ -160,7 +160,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
typename Problem::KDataType,
|
||||
typename Problem::AccDataType,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kQKHeaddim,
|
||||
Problem::BlockFmhaShape::kQKHeaddimForGemmN,
|
||||
Problem::BlockFmhaShape::kK4>,
|
||||
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
|
||||
@@ -191,7 +191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
|
||||
|
||||
@@ -210,7 +210,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
|
||||
|
||||
@@ -229,7 +229,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
|
||||
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -249,7 +249,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
|
||||
|
||||
@@ -310,7 +310,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -322,7 +322,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
return total_pixels / GetAlignmentK<Problem>();
|
||||
@@ -333,7 +333,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
|
||||
@@ -371,7 +371,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -417,7 +417,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -440,7 +440,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -811,7 +811,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
|
||||
@@ -935,7 +935,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
@@ -961,7 +961,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentK<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -982,7 +982,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
|
||||
{
|
||||
// Hold all data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
@@ -994,7 +994,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor<Problem>();
|
||||
@@ -1017,7 +1017,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
@@ -1043,7 +1043,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
@@ -1087,7 +1087,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentQ<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1108,7 +1108,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
|
||||
{
|
||||
// Hold full block data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
@@ -1121,7 +1121,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor()
|
||||
{
|
||||
// Hold full block data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor<Problem>();
|
||||
@@ -1144,7 +1144,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
@@ -1250,7 +1250,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
// Hold full block data
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
|
||||
@@ -1294,7 +1294,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
|
||||
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
@@ -1315,7 +1315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
|
||||
{
|
||||
// Hold all data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
|
||||
@@ -1328,7 +1328,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor()
|
||||
{
|
||||
// Hold all data
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor<Problem>();
|
||||
|
||||
@@ -1350,7 +1350,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddimForGemmN;
|
||||
// constexpr index_t kNPerBlock = 32;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
@@ -1849,12 +1849,16 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
}
|
||||
|
||||
private:
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
|
||||
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
|
||||
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
|
||||
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kQKHeaddimForGemmN = Problem::BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kVHeaddimForGemmN = Problem::BlockFmhaShape::kVHeaddim;
|
||||
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
|
||||
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
|
||||
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
|
||||
|
||||
static constexpr index_t WarpGemmM =
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
|
||||
@@ -1868,34 +1872,34 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
// Compute
|
||||
static constexpr index_t Gemm0MFMA =
|
||||
kM0 * kN0 * kQKHeaddim /
|
||||
kM0 * kN0 * kK0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm1MFMA =
|
||||
kM0 * kN0 * kVHeaddim /
|
||||
kN0 * kVHeaddimForGemmN * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm2MFMA =
|
||||
kN0 * kVHeaddim * kM0 /
|
||||
kM0 * kN0 * kK2 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm3MFMA =
|
||||
kN0 * kQKHeaddim * kM0 /
|
||||
kN0 * kQKHeaddimForGemmN * kM0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
static constexpr index_t Gemm4MFMA =
|
||||
kM0 * kQKHeaddim * kN0 /
|
||||
kM0 * kQKHeaddimForGemmN * kN0 /
|
||||
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
|
||||
|
||||
// VMEM
|
||||
static constexpr index_t Q_VMEM_READ =
|
||||
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
|
||||
kM0 * kQKHeaddimForGemmN / kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t OGrad_VMEM_READ =
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
kM0 * kVHeaddimForGemmN / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t LSE_VMEM_READ = 1;
|
||||
static constexpr index_t D_VMEM_READ = 1;
|
||||
|
||||
// LDS Read
|
||||
static constexpr index_t OGradT_LDS_READ =
|
||||
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
|
||||
kM0 * kVHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
|
||||
static constexpr index_t QT_LDS_READ =
|
||||
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
kM0 * kQKHeaddimForGemmN / get_warp_size() / GetTransposedAlignmentQ<Problem>();
|
||||
static constexpr index_t SGradT_LDS_READ_P1 =
|
||||
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
|
||||
static constexpr index_t Q_LDS_READ =
|
||||
@@ -1909,13 +1913,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
|
||||
|
||||
// LDS Write
|
||||
static constexpr index_t Q_LDS_WRITE =
|
||||
kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
|
||||
kM0 * kQKHeaddimForGemmN / Problem::kBlockSize / GetAlignmentQ<Problem>();
|
||||
static constexpr index_t QT_LDS_WRITE =
|
||||
kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
|
||||
kM0 * kQKHeaddimForGemmN / kBlockSize / GetTransposedAlignmentQ<Problem>();
|
||||
static constexpr index_t OGrad_LDS_WRITE =
|
||||
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
kM0 * kVHeaddimForGemmN / kBlockSize / GetAlignmentOGrad<Problem>();
|
||||
static constexpr index_t OGradT_LDS_WRITE =
|
||||
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
|
||||
kM0 * kVHeaddimForGemmN / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
|
||||
static constexpr index_t LSE_LDS_WRITE = 1;
|
||||
static constexpr index_t D_LDS_WRITE = 1;
|
||||
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
|
||||
|
||||
@@ -55,13 +55,14 @@ struct BlockFmhaBwdPipelineProblem
|
||||
static constexpr bool kIsDeterministic = kIsDeterministic_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
static constexpr bool kPadSeqLenK = Traits::kPadSeqLenK;
|
||||
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
|
||||
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
|
||||
static constexpr bool kPadHeadDimDoDv = Traits::kPadHeadDimDoDv;
|
||||
static constexpr auto BiasEnum = Traits::BiasEnum;
|
||||
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
|
||||
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
|
||||
};
|
||||
|
||||
template <typename ODataType_,
|
||||
|
||||
@@ -7,6 +7,20 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
|
||||
{
|
||||
if(len == 96)
|
||||
return 128;
|
||||
if(len == 160)
|
||||
return 256;
|
||||
|
||||
// only length of 96, 160 and power-of-two is supported
|
||||
if(!(len & (len - 1)))
|
||||
return len;
|
||||
|
||||
return 0;
|
||||
};
|
||||
|
||||
template <typename BlockTile_, // sequence<...
|
||||
typename Gemm0BlockWarps_,
|
||||
typename Gemm0WarpTile_,
|
||||
@@ -90,6 +104,13 @@ struct TileFmhaBwdShape
|
||||
// K/K^T at once
|
||||
static constexpr index_t kVHeaddim = BlockTile::at(number<8>{}); // V headdim, used for pipeline
|
||||
// that need load V at once
|
||||
|
||||
static constexpr index_t kQKHeaddimForGemmN = ceil_to_qualified_tile_length(kQKHeaddim);
|
||||
|
||||
static constexpr index_t kVHeaddimForGemmN = ceil_to_qualified_tile_length(kVHeaddim);
|
||||
|
||||
static_assert(kQKHeaddimForGemmN > 0, "Check failed!");
|
||||
static_assert(kVHeaddimForGemmN > 0, "Check failed!");
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -33,6 +33,30 @@ struct TileFmhaTraits
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
bool kPadHeadDimV_ /* paddding for hdim_v */,
|
||||
BlockAttentionBiasEnum BiasEnum_,
|
||||
bool kHasBiasGrad_,
|
||||
bool kStoreLSE_,
|
||||
bool kHasDropout_,
|
||||
bool kDoFp8StaticQuant_,
|
||||
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
|
||||
struct TileFmhaBwdTraits
|
||||
{
|
||||
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
|
||||
static constexpr bool kPadSeqLenK = kPadSeqLenK_;
|
||||
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
|
||||
static constexpr bool kPadHeadDimV = kPadHeadDimV_;
|
||||
static constexpr auto BiasEnum = BiasEnum_;
|
||||
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
|
||||
static constexpr bool kStoreLSE = kStoreLSE_;
|
||||
static constexpr bool kHasDropout = kHasDropout_;
|
||||
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
|
||||
static constexpr index_t kBlockPerCu = kBlockPerCu_;
|
||||
};
|
||||
|
||||
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
|
||||
bool kPadSeqLenK_ /* padding for seqlen_k */,
|
||||
bool kPadHeadDimQ_ /* paddding for hdim_q */,
|
||||
|
||||
Reference in New Issue
Block a user