mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[rocm-libraries] ROCm/rocm-libraries#4368 (commit 17f7dfc)
[CK_TILE][FMHA] Support microscaling (mxfp8 and mxfp4) on gfx950 (#4368) ## Motivation Microscaling types (mxfp8 and mxfp4) for fwd qr pipeline ## Technical Details The microscaling is used when quant scale mode is `BlockAttentionQuantScaleEnum::MX` and `Q/K/P/VDataType` are fp8/bf8/fp4. Supported features: * only "qr" pipeline is implemented * hdim 128 and 256 (smaller hdim are not possible due to restrictions of "qr" pipeline, but they can be computed using instances with padding) * both 32x32x64 and 16x16x128 scale MFMAs are supported * Q and K scales are applied in hdim, V scales - in seqlen dimension * column-major V only * batch and group mode * bias, Alibi (tested but no instances by default, just like fp8) * masking etc. Aiter PR with new API args: https://github.com/ROCm/aiter/pull/2008 ## Test Plan ``` ninja test_ck_tile_fmha_fwd_mxfp8 && bin/test_ck_tile_fmha_fwd_mxfp8 ninja test_ck_tile_fmha_fwd_mxfp4 && bin/test_ck_tile_fmha_fwd_mxfp4 ``` ## Test Result The tests must pass. ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This commit is contained in:
committed by
assistant-librarian[bot]
parent
c85c272c39
commit
2312eef6c3
@@ -191,6 +191,29 @@ struct FmhaFwdKernel
|
||||
const int32_t* block_scale_seqstart_k_ptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonMXKargs : FmhaFwdCommonQScaleKargs
|
||||
{
|
||||
ck_tile::index_t stride_q_descale;
|
||||
ck_tile::index_t stride_k_descale;
|
||||
ck_tile::index_t stride_v_descale;
|
||||
|
||||
ck_tile::index_t nhead_stride_q_descale;
|
||||
ck_tile::index_t nhead_stride_k_descale;
|
||||
ck_tile::index_t nhead_stride_v_descale;
|
||||
};
|
||||
|
||||
struct FmhaFwdBatchMXKargs : FmhaFwdCommonMXKargs
|
||||
{
|
||||
ck_tile::index_t batch_stride_q_descale;
|
||||
ck_tile::index_t batch_stride_k_descale;
|
||||
ck_tile::index_t batch_stride_v_descale;
|
||||
};
|
||||
|
||||
struct FmhaFwdGroupMXKargs : FmhaFwdCommonMXKargs
|
||||
{
|
||||
const int32_t* seqstart_v_scale_ptr;
|
||||
};
|
||||
|
||||
struct FmhaFwdCommonLSEKargs
|
||||
{
|
||||
void* lse_ptr = nullptr;
|
||||
@@ -271,7 +294,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdBatchBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::MX,
|
||||
FmhaFwdBatchMXKargs,
|
||||
FmhaFwdEmptyKargs<3>>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdBatchModeDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>
|
||||
{
|
||||
@@ -300,7 +325,9 @@ struct FmhaFwdKernel
|
||||
FmhaFwdCommonQScaleKargs,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE,
|
||||
FmhaFwdGroupBlockScaleKargs,
|
||||
FmhaFwdEmptyKargs<3>>>,
|
||||
std::conditional_t<QScaleEnum == BlockAttentionQuantScaleEnum::MX,
|
||||
FmhaFwdGroupMXKargs,
|
||||
FmhaFwdEmptyKargs<3>>>>,
|
||||
std::conditional_t<kHasDropout, FmhaFwdCommonDropoutKargs, FmhaFwdEmptyKargs<4>>,
|
||||
std::conditional_t<kHasLogitsSoftCap, FmhaFwdLogitsSoftCapKargs, FmhaFwdEmptyKargs<5>>,
|
||||
std::conditional_t<kSkipMinSeqlenQ, FmhaFwdSkipMinSeqlenQKargs, FmhaFwdEmptyKargs<6>>
|
||||
@@ -350,6 +377,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t stride_q_descale,
|
||||
ck_tile::index_t stride_k_descale,
|
||||
ck_tile::index_t stride_v_descale,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
@@ -450,7 +480,7 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
@@ -467,6 +497,24 @@ struct FmhaFwdKernel
|
||||
kargs.block_scale_size_q = block_scale_size_q;
|
||||
kargs.block_scale_size_kv = block_scale_size_kv;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.stride_q_descale = stride_q_descale;
|
||||
kargs.stride_k_descale = stride_k_descale;
|
||||
kargs.stride_v_descale = stride_v_descale;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.batch_stride_q_descale = batch_stride_q_descale;
|
||||
kargs.batch_stride_k_descale = batch_stride_k_descale;
|
||||
kargs.batch_stride_v_descale = batch_stride_v_descale;
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -525,6 +573,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t stride_q_descale,
|
||||
ck_tile::index_t stride_k_descale,
|
||||
ck_tile::index_t stride_v_descale,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
@@ -583,6 +634,9 @@ struct FmhaFwdKernel
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
stride_q_descale,
|
||||
stride_k_descale,
|
||||
stride_v_descale,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
@@ -644,6 +698,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t stride_q_descale,
|
||||
ck_tile::index_t stride_k_descale,
|
||||
ck_tile::index_t stride_v_descale,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
@@ -702,6 +759,9 @@ struct FmhaFwdKernel
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
stride_q_descale,
|
||||
stride_k_descale,
|
||||
stride_v_descale,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
@@ -754,6 +814,7 @@ struct FmhaFwdKernel
|
||||
const void* seqlen_k_ptr,
|
||||
const void* block_scale_seqstart_q_ptr,
|
||||
const void* block_scale_seqstart_k_ptr,
|
||||
const void* seqstart_v_scale_ptr,
|
||||
ck_tile::index_t hdim_q,
|
||||
ck_tile::index_t hdim_v,
|
||||
ck_tile::index_t num_head_q,
|
||||
@@ -766,6 +827,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t stride_q_descale,
|
||||
ck_tile::index_t stride_k_descale,
|
||||
ck_tile::index_t stride_v_descale,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
@@ -856,7 +920,7 @@ struct FmhaFwdKernel
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
@@ -874,6 +938,22 @@ struct FmhaFwdKernel
|
||||
kargs.block_scale_seqstart_k_ptr =
|
||||
reinterpret_cast<const int32_t*>(block_scale_seqstart_k_ptr);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
kargs.q_descale_ptr = q_descale_ptr;
|
||||
kargs.k_descale_ptr = k_descale_ptr;
|
||||
kargs.v_descale_ptr = v_descale_ptr;
|
||||
|
||||
kargs.stride_q_descale = stride_q_descale;
|
||||
kargs.stride_k_descale = stride_k_descale;
|
||||
kargs.stride_v_descale = stride_v_descale;
|
||||
|
||||
kargs.nhead_stride_q_descale = nhead_stride_q_descale;
|
||||
kargs.nhead_stride_k_descale = nhead_stride_k_descale;
|
||||
kargs.nhead_stride_v_descale = nhead_stride_v_descale;
|
||||
|
||||
kargs.seqstart_v_scale_ptr = reinterpret_cast<const int32_t*>(seqstart_v_scale_ptr);
|
||||
}
|
||||
if constexpr(kHasDropout)
|
||||
{
|
||||
if(drop_seed_offset.index() == 0) // seed & offset come from host
|
||||
@@ -939,6 +1019,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t stride_q_descale,
|
||||
ck_tile::index_t stride_k_descale,
|
||||
ck_tile::index_t stride_v_descale,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
@@ -992,6 +1075,9 @@ struct FmhaFwdKernel
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
stride_q_descale,
|
||||
stride_k_descale,
|
||||
stride_v_descale,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
@@ -1048,6 +1134,9 @@ struct FmhaFwdKernel
|
||||
ck_tile::index_t stride_bias,
|
||||
ck_tile::index_t stride_randval,
|
||||
ck_tile::index_t stride_o,
|
||||
ck_tile::index_t stride_q_descale,
|
||||
ck_tile::index_t stride_k_descale,
|
||||
ck_tile::index_t stride_v_descale,
|
||||
ck_tile::index_t nhead_stride_q,
|
||||
ck_tile::index_t nhead_stride_k,
|
||||
ck_tile::index_t nhead_stride_v,
|
||||
@@ -1101,6 +1190,9 @@ struct FmhaFwdKernel
|
||||
stride_bias,
|
||||
stride_randval,
|
||||
stride_o,
|
||||
stride_q_descale,
|
||||
stride_k_descale,
|
||||
stride_v_descale,
|
||||
nhead_stride_q,
|
||||
nhead_stride_k,
|
||||
nhead_stride_v,
|
||||
@@ -1303,6 +1395,12 @@ struct FmhaFwdKernel
|
||||
batch_offset_k_descale = bkey_start;
|
||||
batch_offset_v_descale = bkey_start;
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
batch_offset_q_descale = query_start * kargs.stride_q_descale;
|
||||
batch_offset_k_descale = key_start * kargs.stride_k_descale;
|
||||
batch_offset_v_descale = kargs.seqstart_v_scale_ptr[i_batch];
|
||||
}
|
||||
batch_offset_o = query_start * kargs.stride_o;
|
||||
|
||||
// real logical lengths (exclude PAD)
|
||||
@@ -1370,7 +1468,8 @@ struct FmhaFwdKernel
|
||||
batch_offset_randval =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
|
||||
}
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE ||
|
||||
QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
batch_offset_q_descale =
|
||||
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q_descale;
|
||||
@@ -1395,17 +1494,20 @@ struct FmhaFwdKernel
|
||||
}
|
||||
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
|
||||
|
||||
const QDataType* q_ptr =
|
||||
reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
(static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q + batch_offset_q) /
|
||||
numeric_traits<QDataType>::PackedSize;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
(static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k + batch_offset_k) /
|
||||
numeric_traits<KDataType>::PackedSize;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
(static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v + batch_offset_v) /
|
||||
numeric_traits<VDataType>::PackedSize;
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
batch_offset_o;
|
||||
@@ -1698,9 +1800,9 @@ struct FmhaFwdKernel
|
||||
}
|
||||
}();
|
||||
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead / kargs.nhead_ratio_qk};
|
||||
BlockIndices block_indices{i_batch, i_nhead, i_nhead_k};
|
||||
|
||||
auto o_acc_tile = [&, i_nhead_ = i_nhead]() {
|
||||
auto o_acc_tile = [&, i_nhead_ = i_nhead, i_nhead_k_ = i_nhead_k]() {
|
||||
if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::PERTENSOR)
|
||||
{
|
||||
// TODO - move global load of descale to pipeline
|
||||
@@ -1744,6 +1846,9 @@ struct FmhaFwdKernel
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::BLOCKSCALE)
|
||||
@@ -1795,8 +1900,144 @@ struct FmhaFwdKernel
|
||||
k_descale_ptr,
|
||||
v_descale_ptr,
|
||||
kargs.block_scale_size_kv,
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
make_null_tile_window(make_tuple()),
|
||||
sink_value);
|
||||
}
|
||||
else if constexpr(QScaleEnum == BlockAttentionQuantScaleEnum::MX)
|
||||
{
|
||||
using QScaleDataType = typename FmhaPipeline::QScaleDataType;
|
||||
using KScaleDataType = typename FmhaPipeline::KScaleDataType;
|
||||
using VScaleDataType = typename FmhaPipeline::VScaleDataType;
|
||||
|
||||
constexpr ck_tile::index_t kQKScaleGranularity =
|
||||
FmhaPipeline::kQKScaleGranularity;
|
||||
constexpr ck_tile::index_t kVScaleGranularity =
|
||||
FmhaPipeline::kVScaleGranularity;
|
||||
|
||||
const QScaleDataType* q_descale_ptr =
|
||||
reinterpret_cast<const QScaleDataType*>(kargs.q_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_q_descale +
|
||||
batch_offset_q_descale;
|
||||
const KScaleDataType* k_descale_ptr =
|
||||
reinterpret_cast<const KScaleDataType*>(kargs.k_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_k_) * kargs.nhead_stride_k_descale +
|
||||
batch_offset_k_descale;
|
||||
const VScaleDataType* v_descale_ptr =
|
||||
reinterpret_cast<const VScaleDataType*>(kargs.v_descale_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_k_) * kargs.nhead_stride_v_descale +
|
||||
batch_offset_v_descale;
|
||||
|
||||
const ck_tile::index_t hdim_q_scale =
|
||||
ck_tile::integer_divide_ceil(kargs.hdim_q, kQKScaleGranularity);
|
||||
const ck_tile::index_t seqlen_v_scale =
|
||||
ck_tile::integer_divide_ceil(kargs.seqlen_k, kVScaleGranularity);
|
||||
|
||||
// Custom invalid_element_value is required for e8m0_t scales because
|
||||
// the default (numeric<e8m0_t>>::zero()) is NaN
|
||||
const auto q_scale_dram = [&]() {
|
||||
auto desc =
|
||||
make_naive_tensor_descriptor(make_tuple(kargs.seqlen_q, hdim_q_scale),
|
||||
make_tuple(kargs.stride_q_descale, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
q_descale_ptr,
|
||||
desc.get_element_space_size(),
|
||||
type_convert<QScaleDataType>(1.0f));
|
||||
return pad_tensor_view(
|
||||
tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc},
|
||||
make_tuple(
|
||||
number<FmhaPipeline::kM0>{},
|
||||
number<(FmhaPipeline::kQLoadOnce ? FmhaPipeline::kSubQKHeaddim
|
||||
: FmhaPipeline::kK0) /
|
||||
kQKScaleGranularity>{}),
|
||||
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto k_scale_dram = [&]() {
|
||||
auto desc =
|
||||
make_naive_tensor_descriptor(make_tuple(kargs.seqlen_k, hdim_q_scale),
|
||||
make_tuple(kargs.stride_k_descale, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
k_descale_ptr,
|
||||
desc.get_element_space_size(),
|
||||
type_convert<KScaleDataType>(1.0f));
|
||||
return pad_tensor_view(
|
||||
tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc},
|
||||
make_tuple(number<FmhaPipeline::kN0>{},
|
||||
number<FmhaPipeline::kK0 / kQKScaleGranularity>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
}();
|
||||
const auto v_scale_dram = [&]() {
|
||||
static_assert(
|
||||
std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
auto desc =
|
||||
make_naive_tensor_descriptor(make_tuple(kargs.hdim_v, seqlen_v_scale),
|
||||
make_tuple(kargs.stride_v_descale, 1),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
auto buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
v_descale_ptr,
|
||||
desc.get_element_space_size(),
|
||||
type_convert<VScaleDataType>(1.0f));
|
||||
return pad_tensor_view(
|
||||
tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc},
|
||||
make_tuple(number<FmhaPipeline::kN1>{},
|
||||
number<FmhaPipeline::kK1 / kVScaleGranularity>{}),
|
||||
sequence<false, kPadSeqLenK>{});
|
||||
}();
|
||||
|
||||
auto q_scale_dram_window = make_tile_window(
|
||||
q_scale_dram,
|
||||
make_tuple(number<FmhaPipeline::kM0>{},
|
||||
number<(FmhaPipeline::kQLoadOnce ? FmhaPipeline::kSubQKHeaddim
|
||||
: FmhaPipeline::kK0) /
|
||||
kQKScaleGranularity>{}),
|
||||
{i_m0, 0});
|
||||
auto k_scale_dram_window = make_tile_window(
|
||||
k_scale_dram,
|
||||
make_tuple(number<FmhaPipeline::kN0>{},
|
||||
number<FmhaPipeline::kK0 / kQKScaleGranularity>{}),
|
||||
{0, 0});
|
||||
auto v_scale_dram_window = make_tile_window(
|
||||
v_scale_dram,
|
||||
make_tuple(number<FmhaPipeline::kN1>{},
|
||||
number<FmhaPipeline::kK1 / kVScaleGranularity>{}),
|
||||
{i_n1, 0});
|
||||
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
identity{}, // q_element_func
|
||||
k_dram_window,
|
||||
identity{}, // k_element_func
|
||||
v_dram_window,
|
||||
identity{}, // v_element_func
|
||||
bias_dram_window,
|
||||
identity{}, // bias_element_func
|
||||
randval_dram_window,
|
||||
lse_dram_window,
|
||||
identity{}, // lse_element_func
|
||||
identity{}, // s_acc_element_func
|
||||
identity{}, // p_compute_element_func
|
||||
identity{}, // o_acc_element_func
|
||||
mask,
|
||||
position_encoding,
|
||||
kargs.scale_s,
|
||||
variant,
|
||||
variant_params,
|
||||
block_indices,
|
||||
smem_ptr,
|
||||
dropout,
|
||||
nullptr,
|
||||
nullptr,
|
||||
1,
|
||||
q_scale_dram_window,
|
||||
k_scale_dram_window,
|
||||
v_scale_dram_window,
|
||||
sink_value);
|
||||
}
|
||||
else
|
||||
{
|
||||
return FmhaPipeline{}(q_dram_window,
|
||||
@@ -1969,15 +2210,18 @@ struct FmhaFwdKernel
|
||||
// for simplicity, batch stride we just modify the pointer
|
||||
const index_t i_nhead_k = i_nhead / kargs.nhead_ratio_qk;
|
||||
|
||||
const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
|
||||
batch_offset_q;
|
||||
const KDataType* k_ptr = reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
|
||||
batch_offset_k;
|
||||
const VDataType* v_ptr = reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
|
||||
batch_offset_v;
|
||||
const QDataType* q_ptr =
|
||||
reinterpret_cast<const QDataType*>(kargs.q_ptr) +
|
||||
(static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q + batch_offset_q) /
|
||||
numeric_traits<QDataType>::PackedSize;
|
||||
const KDataType* k_ptr =
|
||||
reinterpret_cast<const KDataType*>(kargs.k_ptr) +
|
||||
(static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k + batch_offset_k) /
|
||||
numeric_traits<KDataType>::PackedSize;
|
||||
const VDataType* v_ptr =
|
||||
reinterpret_cast<const VDataType*>(kargs.v_ptr) +
|
||||
(static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v + batch_offset_v) /
|
||||
numeric_traits<VDataType>::PackedSize;
|
||||
|
||||
ODataType* o_ptr = reinterpret_cast<ODataType*>(kargs.o_ptr) +
|
||||
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o +
|
||||
@@ -2006,7 +2250,8 @@ struct FmhaFwdKernel
|
||||
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
|
||||
sequence<false, kPadHeadDimQ>{});
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(QDataType);
|
||||
constexpr index_t LDSLayerSize =
|
||||
256 * numeric_traits<QDataType>::PackedSize / sizeof(QDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
|
||||
if constexpr(XorLengthFold > 1)
|
||||
@@ -2130,7 +2375,8 @@ struct FmhaFwdKernel
|
||||
FmhaPipeline::kKLoadOnce ? FmhaPipeline::kQKHeaddim : FmhaPipeline::kK0;
|
||||
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(KDataType);
|
||||
constexpr index_t LDSLayerSize =
|
||||
256 * numeric_traits<KDataType>::PackedSize / sizeof(KDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
|
||||
if constexpr(XorLengthFold > 1)
|
||||
@@ -2254,7 +2500,8 @@ struct FmhaFwdKernel
|
||||
sequence<kPadSeqLenK, false>{});
|
||||
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr index_t LDSLayerSize = 256 / sizeof(VDataType);
|
||||
constexpr index_t LDSLayerSize =
|
||||
256 * numeric_traits<VDataType>::PackedSize / sizeof(VDataType);
|
||||
constexpr index_t XorLengthFold = LDSLayerSize / (FmhaPipeline::kQKHeaddim);
|
||||
|
||||
if constexpr(XorLengthFold > 1)
|
||||
|
||||
Reference in New Issue
Block a user