mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Fixes
This commit is contained in:
@@ -1,114 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
|
||||
// keep sync with BlockAttentionBiasEnum
|
||||
enum class bias_enum
|
||||
{
|
||||
no_bias = 0,
|
||||
elementwise_bias = 1,
|
||||
alibi = 2,
|
||||
};
|
||||
|
||||
struct bias_info
|
||||
{
|
||||
bias_enum type;
|
||||
/*
|
||||
* simple dispatch logic
|
||||
*
|
||||
* if type == elementwise_bias:
|
||||
* if rank_info == 0:
|
||||
* bias is 1*1*s*s
|
||||
* elif rank_info == 1:
|
||||
* bias is 1*h*s*s
|
||||
* elif rank_info == 2:
|
||||
* bias is b*h*s*s
|
||||
*
|
||||
* elif type == alibi:
|
||||
* if rank_info == 0:
|
||||
* alibi in 1*h
|
||||
* elif rank_info == 1:
|
||||
* alibi in b*h
|
||||
*/
|
||||
int rank_info;
|
||||
|
||||
void serialize(std::ostream& os) const
|
||||
{
|
||||
if(type == bias_enum::no_bias)
|
||||
os << "n";
|
||||
else if(type == bias_enum::elementwise_bias)
|
||||
{
|
||||
os << "e";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
else if(type == bias_enum::alibi)
|
||||
{
|
||||
os << "alibi";
|
||||
if(rank_info != 0)
|
||||
{
|
||||
os << "[" << rank_info << "]";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static bias_info decode(std::string str)
|
||||
{
|
||||
bias_info info{bias_enum::no_bias, 0};
|
||||
auto found_0 = str.find(':');
|
||||
if(found_0 != std::string::npos)
|
||||
{
|
||||
std::string t = str.substr(0, found_0);
|
||||
std::string v = str.substr(found_0 + 1);
|
||||
if(t == "e" || t == "elementwise")
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
info.rank_info = std::stoi(v);
|
||||
if(info.rank_info < 0 || info.rank_info > 2)
|
||||
throw std::invalid_argument("invalid bias rank: " + str);
|
||||
}
|
||||
else if(t == "a" || t == "alibi")
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
info.rank_info = std::stoi(v);
|
||||
if(info.rank_info < 0 || info.rank_info > 1)
|
||||
throw std::invalid_argument("invalid bias rank: " + str);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid bias value: " + str);
|
||||
}
|
||||
}
|
||||
else if(str == "0" || str == "n")
|
||||
{
|
||||
info.type = bias_enum::no_bias;
|
||||
}
|
||||
else if(str == "1" || str == "e" || str == "elementwise")
|
||||
{
|
||||
info.type = bias_enum::elementwise_bias;
|
||||
}
|
||||
else if(str == "2" || str == "a" || str == "alibi")
|
||||
{
|
||||
info.type = bias_enum::alibi;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::invalid_argument("invalid bias value: " + str);
|
||||
}
|
||||
return info;
|
||||
}
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& os, const bias_info& bi)
|
||||
{
|
||||
bi.serialize(os);
|
||||
return os;
|
||||
}
|
||||
};
|
||||
@@ -1,5 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
GEN_DIR = "" # in Cmake, have to generate files in same folder
|
||||
@@ -1,138 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
FWD_DTYPE_MAP = {
|
||||
"fp32" : "FmhaFwdFp32",
|
||||
"fp16" : "FmhaFwdFp16",
|
||||
"bf16" : "FmhaFwdBf16",
|
||||
"fp8" : "FmhaFwdFp8",
|
||||
"fp8fp16": "FmhaFwdFp8Fp16",
|
||||
"fp8bf16": "FmhaFwdFp8Bf16",
|
||||
"fp8fp32": "FmhaFwdFp8Fp32"
|
||||
}
|
||||
|
||||
BWD_DTYPE_MAP = {
|
||||
"fp32": "FmhaBwdFp32",
|
||||
"fp16": "FmhaBwdFp16",
|
||||
"bf16": "FmhaBwdBf16"
|
||||
}
|
||||
|
||||
MASK_IMPL = {
|
||||
"generic" : "ck_tile::GenericAttentionMask",
|
||||
"simplified" : "ck_tile::SimplifiedGenericAttentionMask"
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_MAP = {
|
||||
"s_no" : "ck_tile::SimplifiedGenericAttentionMask<false>",
|
||||
"s_mask" : "ck_tile::SimplifiedGenericAttentionMask<true>",
|
||||
}
|
||||
|
||||
_MASK_MAP = {
|
||||
"no" : "FmhaMasks::NoMask",
|
||||
"causal" : "FmhaMasks::CausalMask",
|
||||
"generic" : "FmhaMasks::GenericMask"
|
||||
}
|
||||
|
||||
def get_mask_map(mask : str):
|
||||
if mask == "generic":
|
||||
return _MASK_MAP
|
||||
elif mask == "simplified":
|
||||
return _MASK_SIMPLIFIED_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
_MASK_CHECK_MAP = {
|
||||
"no" : "t.mask_type == mask_enum::no_mask",
|
||||
"causal" : "t.mask_type == mask_enum::mask_top_left || t.mask_type == mask_enum::mask_bottom_right",
|
||||
"generic" : "t.mask_type == mask_enum::window_generic",
|
||||
}
|
||||
|
||||
_MASK_SIMPLIFIED_CHECK_MAP = {
|
||||
"s_no" : "t.mask_type == mask_enum::no_mask",
|
||||
"s_mask" : "t.mask_type != mask_enum::no_mask",
|
||||
}
|
||||
|
||||
def get_mask_check_map(mask : str):
|
||||
if mask == "generic":
|
||||
return _MASK_CHECK_MAP
|
||||
elif mask == "simplified":
|
||||
return _MASK_SIMPLIFIED_CHECK_MAP
|
||||
else:
|
||||
assert False
|
||||
return None
|
||||
|
||||
BIAS_MAP = {
|
||||
"no" : "ck_tile::BlockAttentionBiasEnum::NO_BIAS",
|
||||
"bias" : "ck_tile::BlockAttentionBiasEnum::ELEMENTWISE_BIAS",
|
||||
"alibi" : "ck_tile::BlockAttentionBiasEnum::ALIBI"
|
||||
}
|
||||
|
||||
# TODO: this is ugly
|
||||
BIAS_CHECK_MAP = {
|
||||
"no" : "bias_enum::no_bias",
|
||||
"bias" : "bias_enum::elementwise_bias",
|
||||
"alibi" : "bias_enum::alibi"
|
||||
}
|
||||
|
||||
DROPOUT_MAP = {
|
||||
"no" : "ck_tile::BlockDropoutBwd<false, true, false>",
|
||||
"dropout_wg32" : "ck_tile::BlockDropoutBwd<true, true, false>",
|
||||
"dropout_wg32_storerandval" : "ck_tile::BlockDropoutBwd<true, true, true >",
|
||||
"dropout_wg16" : "ck_tile::BlockDropoutBwd<true, false, false>",
|
||||
"dropout_wg16_storerandval" : "ck_tile::BlockDropoutBwd<true, false, true >"
|
||||
}
|
||||
|
||||
DROPOUT_CHECK_MAP = {
|
||||
"no" : "t.has_dropout == false",
|
||||
"dropout_wg32" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg32_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
"dropout_wg16" : "t.has_dropout == true && t.is_store_randval == false",
|
||||
"dropout_wg16_storerandval" : "t.has_dropout == true && t.is_store_randval == true",
|
||||
}
|
||||
|
||||
ROPE_MAP = {
|
||||
"no" : "ck_tile::RotaryEmbeddingEnum::NONE",
|
||||
"inter" : "ck_tile::RotaryEmbeddingEnum::INTERLEAVED",
|
||||
"half" : "ck_tile::RotaryEmbeddingEnum::HALF_ROTATED"
|
||||
}
|
||||
|
||||
ROPE_CHECK_MAP = {
|
||||
"no" : "rope_enum::none",
|
||||
"inter" : "rope_enum::interleaved",
|
||||
"half" : "rope_enum::half_rotated"
|
||||
}
|
||||
|
||||
MODE_MAP = {
|
||||
"batch" : "false",
|
||||
"group" : "true"
|
||||
}
|
||||
|
||||
LAYOUT_MAP = {
|
||||
"row" : "true",
|
||||
"col" : "false"
|
||||
}
|
||||
|
||||
PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC",
|
||||
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
"qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
"t" : "true",
|
||||
"f" : "false",
|
||||
True : "true",
|
||||
False : "false",
|
||||
}
|
||||
@@ -1,633 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
256: 256
|
||||
}
|
||||
|
||||
FMHA_BATCH_PREFILL_PIPELINE_MAP = {
|
||||
"qr_async" : "ck_tile::BlockFmhaBatchPrefillPipelineQRKSVSAsync",
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
false,
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaBatchPrefillWithPagedKVCacheKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_batch_prefill_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_batch_prefill_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_batch_prefill_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_batch_prefill_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
#include <cstdio>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cu) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cu = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_batch_prefill(fmha_batch_prefill_traits t, fmha_batch_prefill_args a, const ck_tile::stream_config& s) {{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, false>;
|
||||
return fmha_batch_prefill_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return 'true'
|
||||
else:
|
||||
return f'{self.bool_expr}'
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f'({str(self)}) && ({str(other)})')
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
constraint : CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_fp8']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_squant=BOOL_MAP[trait.squant],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_BATCH_PREFILL_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_batch_prefill_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
|
||||
class KernelComponentFactory:
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
128 : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
# pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
if 128 in result.keys():
|
||||
result[128].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
|
||||
return result
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = CustomFactory.get_hdim_tile_size_dict(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for (hdim, tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for tile, pipeline in itertools.product(tiles, CustomFactory.get_pipelines(dtype, hdim, receipt, mask_impl)):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_batch_prefill) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_batch_prefill C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
@@ -1,929 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Literal, Any
|
||||
from collections import defaultdict
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
from codegen.utils import update_file
|
||||
|
||||
|
||||
FMHA_BWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "fmha_bwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::
|
||||
sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bk1}, {F_bk2}, {F_bk3}, {F_bk4}, {F_bhdq}, {F_bhdv}>;
|
||||
using fmha_block_warps0_{F_idx} = ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>;
|
||||
using fmha_block_warps1_{F_idx} = ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>;
|
||||
using fmha_block_warps2_{F_idx} = ck_tile::sequence<{F_rm2}, {F_rn2}, {F_rk2}>;
|
||||
using fmha_warp_tile0_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>;
|
||||
using fmha_warp_tile1_{F_idx} = ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>;
|
||||
using fmha_warp_tile2_{F_idx} = ck_tile::sequence<{F_wm0}, {F_wn0}, ck_tile::min({F_wk0}, {F_bk4})>;
|
||||
|
||||
// TODO: simplify Gemm0~4BlockWarps in TileFmhaBwdShape
|
||||
// G0&G2 -> GSdP
|
||||
// G1&G3 -> GdKV
|
||||
// G4 -> GdQ
|
||||
using fmha_bwd_shape_{F_idx} = ck_tile::TileFmhaBwdShape<fmha_block_tile_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile0_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps0_{F_idx},
|
||||
fmha_warp_tile0_{F_idx},
|
||||
fmha_block_warps1_{F_idx},
|
||||
fmha_warp_tile1_{F_idx},
|
||||
fmha_block_warps2_{F_idx},
|
||||
fmha_warp_tile2_{F_idx},
|
||||
{F_maxq}>;
|
||||
|
||||
using fmha_bwd_trait_{F_idx} = ck_tile::TileFmhaBwdTraits<{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
{F_occupancy}>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
using fmha_dropout_{F_idx} = {F_dropout};
|
||||
|
||||
using fmha_bwd_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::GemmDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::KGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::VGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::BiasGradDataType,
|
||||
fmha_bwd_shape_{F_idx},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
{F_trload},
|
||||
fmha_bwd_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_pipeline_{F_idx} = ck_tile::BlockFmhaBwdDQDKDVPipeline<fmha_bwd_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dk_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::KGradDataType,
|
||||
false,
|
||||
({F_dpad} > 0)>>;
|
||||
|
||||
using fmha_bwd_dv_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::VGradDataType,
|
||||
false,
|
||||
({F_dvpad} > 0)>>;
|
||||
|
||||
using fmha_bwd_dq_epilogue_{F_idx} = ck_tile::Default2DEpilogue<
|
||||
ck_tile::Default2DEpilogueProblem<typename FmhaBwdTypeConfig<{F_dtype}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<{F_dtype}>::QGradDataType,
|
||||
false,
|
||||
({F_dpad} > 0)>>;
|
||||
|
||||
using fmha_bwd_dq_dk_dv_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdDQDKDVKernel<fmha_bwd_pipeline_{F_idx},
|
||||
fmha_bwd_dk_epilogue_{F_idx},
|
||||
fmha_bwd_dv_epilogue_{F_idx},
|
||||
fmha_bwd_dq_epilogue_{F_idx}>;
|
||||
|
||||
using dq_dk_dv_trait_{F_idx} = fmha_bwd_dq_dk_dv_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
{F_mode},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_dropout_{F_idx},
|
||||
{F_bias},
|
||||
{F_dbias},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_deterministic},
|
||||
{F_trload},
|
||||
{F_maxq},
|
||||
{F_bn0}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dq_dk_dv_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_{F_idx}>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dq_dk_dv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template <>
|
||||
int fmha_bwd_dq_dk_dv_maxq_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::kMaxSeqLenQ;
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dq_dk_dv_kernel_{F_idx};
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_BWD_API_FILENAME="fmha_bwd_api.cpp"
|
||||
FMHA_BWD_API="""
|
||||
#include <iostream>
|
||||
|
||||
template <typename dot_do_o_trait_, typename dq_dk_dv_trait_, typename convert_dq_trait_>
|
||||
float fmha_bwd_(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
if constexpr (!std::is_same_v<convert_dq_trait_, void>)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_convert_dq_get_name_<convert_dq_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_convert_dq_oneshot_<convert_dq_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
else
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_>() << "@" << fmha_bwd_dq_dk_dv_get_name_<dq_dk_dv_trait_>() << std::flush;
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_bwd_dq_dk_dv_oneshot_<dq_dk_dv_trait_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
}}
|
||||
|
||||
template <>
|
||||
float fmha_bwd<2>(fmha_bwd_traits t, fmha_bwd_args a, const ck_tile::stream_config& s){{
|
||||
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
def FMHA_BWD_API_COND_STATEMENT(F_cond: str, F_body: str, *, indent=0, if_ = 0) -> str:
|
||||
lines = [
|
||||
f"{'if' if if_ == 0 else 'else if'}({F_cond})",
|
||||
"{",
|
||||
*[' ' + line for line in F_body.split('\n') if line.strip() != ''],
|
||||
"}",
|
||||
]
|
||||
return '\n'.join(' ' * indent + line for line in lines) + '\n'
|
||||
|
||||
|
||||
FMHA_BWD_API_INNER_DISPATCH="""
|
||||
{F_if}((t.is_group_mode == {F_mode}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_dbias == {F_dbias}) && ({F_dropout_check}) &&
|
||||
({F_scheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.is_deterministic == {F_deterministic}){F_cond_extra}) {{
|
||||
using dot_do_o_trait_ = fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dvpad} > 0)>;
|
||||
using dq_dk_dv_trait_ = fmha_bwd_dq_dk_dv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_mask}, {F_dropout}, {F_bias}, {F_dbias}, {F_dpad}, {F_dvpad}, {F_deterministic}, {F_trload}, {F_maxq}, {F_bn0}>;
|
||||
using convert_dq_trait_ = fmha_bwd_convert_dq_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad1d}, ({F_dpad} > 0), {F_deterministic}, {F_convert_dq_bn0}>;
|
||||
r = fmha_bwd_<dot_do_o_trait_, dq_dk_dv_trait_, std::conditional_t<{F_convert_dq_enabled}, convert_dq_trait_, void>>(s, a);
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
# M0 size for 1d kernels (dot/convert)
|
||||
M0_1D = 64
|
||||
|
||||
# GEMM0: Q@K=S^T
|
||||
# GEMM1: P^T@dO^T=dV(This was chosen as G1 to match fwd, but N1 must be equal to headdim_v)
|
||||
# GEMM2: dO@V=dP^T(This was chosen as G2 because of the calculation order)
|
||||
# GEMM3: dS^T@Q^T=dK(Similar to G1, but N3 must be equal to headdim_qk)
|
||||
# GEMM4: dS@K^T=dQ(N4 must be equal to headdim_qk)
|
||||
# Is it necessary to distinguish between K0~K4?
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along gemm0 unroll(F_bhdq)
|
||||
F_bk1 : int # tile size along gemm1 unroll(F_bm0)
|
||||
F_bk2 : int # tile size along gemm2 unroll(F_bhdv)
|
||||
F_bk3 : int # tile size along gemm3 unroll(F_bm0)
|
||||
F_bk4 : int # tile size along gemm4 unroll(F_bn0)
|
||||
F_bhdq : int # q head_dim
|
||||
F_bhdv : int # v head_dim
|
||||
F_rm0 : int # number of warps along q seqlen (block warps) in gemm0/gemm2
|
||||
F_rn0 : int # number of warps along k seqlen (block warps) in gemm0/gemm2
|
||||
F_rk0 : int # number of warps along headdim_qk/v (not used) in gemm0/gemm2
|
||||
F_rm1 : int # number of warps along k seqlen (block warps) in gemm1/gemm3
|
||||
F_rn1 : int # number of warps along headdim_qk/v (block warps) in gemm1/gemm3
|
||||
F_rk1 : int # number of warps along q seqlen (not used) in gemm1/gemm3
|
||||
F_rm2 : int # number of warps along q seqlen (block warps) in gemm4
|
||||
F_rn2 : int # number of warps along headdim_qk (block warps) in gemm4
|
||||
F_rk2 : int # number of warps along k seqlen (not used) in gemm4
|
||||
F_wm0 : int # warp size along m in gemm0/gemm2/gemm4
|
||||
F_wn0 : int # warp size along n in gemm0/gemm2/gemm4
|
||||
F_wk0 : int # warp size along k in gemm0/gemm2/gemm4
|
||||
F_wm1 : int # warp size along m in gemm1/gemm3
|
||||
F_wn1 : int # warp size along n in gemm1/gemm3
|
||||
F_wk1 : int # warp size along k in gemm1/gemm3
|
||||
F_occupancy : int # occupancy
|
||||
max_seq_q : int = 0
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bk1}x{self.F_bk2}x{self.F_bk3}x{self.F_bk4}x{self.F_bhdq}x{self.F_bhdv}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}_r{self.F_rm2}x{self.F_rn2}x{self.F_rk2}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}_o{self.F_occupancy}_maxq{self.max_seq_q}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdDQDKDVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaBwdDQDKDVTileSize
|
||||
F_dpad : Literal[0, 8 ,1]
|
||||
F_dvpad : Literal[0, 8 ,1]
|
||||
F_bias : str #
|
||||
F_dbias : str #
|
||||
F_dropout : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_deterministic : str #
|
||||
mask_impl : str #
|
||||
F_trload : str #
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_BWD_KERNEL_HEADER + \
|
||||
FMHA_BWD_DQ_DK_DV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk2 = self.F_tile.F_bk2,
|
||||
F_bk3 = self.F_tile.F_bk3,
|
||||
F_bk4 = self.F_tile.F_bk4,
|
||||
F_bhdq = self.F_tile.F_bhdq,
|
||||
F_bhdv = self.F_tile.F_bhdv,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_rm2 = self.F_tile.F_rm2,
|
||||
F_rn2 = self.F_tile.F_rn2,
|
||||
F_rk2 = self.F_tile.F_rk2,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_dpad = self.F_dpad,
|
||||
F_dvpad = self.F_dvpad,
|
||||
F_bias = BIAS_MAP[self.F_bias],
|
||||
F_dbias = BOOL_MAP[self.F_dbias],
|
||||
F_dropout = DROPOUT_MAP[self.F_dropout],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic],
|
||||
F_trload = BOOL_MAP[self.F_trload],
|
||||
F_maxq = self.F_tile.max_seq_q
|
||||
)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_dpad : n += f'd{self.F_dpad}'
|
||||
if self.F_dvpad : n += f'dv{self.F_dvpad}'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + self.F_tile.name
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_dbias == 't' : n += '_dbias'
|
||||
else: n += '_ndbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_dropout != 'no' : n += f'_{self.F_dropout}'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
else: n += '_ndeterministic'
|
||||
|
||||
if self.F_trload == 't' : n += '_trload'
|
||||
else: n += '_ntrload'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size.
|
||||
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
|
||||
if dtype == 'fp32' and tr_load == 'f':
|
||||
return [
|
||||
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
|
||||
]
|
||||
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 96, 32, 96, 32, 32, 96, 96, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 128, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 32, 64, 160, 32, 160, 32, 32, 160, 160, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 64, 256, 16, 256, 16, 32, 256, 256, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
]
|
||||
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 't':
|
||||
return [
|
||||
FmhaBwdDQDKDVTileSize( 32, 128, 128, 32, 128, 32, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, 1),
|
||||
FmhaBwdDQDKDVTileSize( 16, 192, 128, 16, 128, 16, 32, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
|
||||
# FmhaBwdDQDKDVTileSize( 16, 32, 128, 16, 128, 16, 32, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 1, 16),
|
||||
FmhaBwdDQDKDVTileSize( 16, 16, 128, 16, 128, 16, 16, 128, 128, 1, 1, 1, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 16, 2, 16),
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_bwd_dot_do_o_trait_{F_idx} =
|
||||
ck_tile::TileFmhaBwdOGradDotOTraits<{F_spad}, {F_dvpad}, {F_occupancy}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_pipeline_problem_{F_idx} = ck_tile::BlockFmhaBwdOGradDotOPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::OGradDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::DDataType,
|
||||
/* BlockSize = M0 = */ 64,
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
fmha_bwd_dot_do_o_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdOGradDotO<fmha_bwd_dot_do_o_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_dot_do_o_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdOGradDotOKernel<fmha_bwd_dot_do_o_{F_idx}>;
|
||||
|
||||
using dot_do_o_trait_{F_idx} =
|
||||
fmha_bwd_dot_do_o_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_dot_do_o_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_dot_do_o_oneshot_<dot_do_o_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_dot_do_o_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_dot_do_o_get_name_<dot_do_o_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_dot_do_o_kernel_{F_idx};
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdOGradDotOKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_spad : str # true/false
|
||||
F_dvpad : str #
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_occupancy : int
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_BWD_KERNEL_HEADER + \
|
||||
FMHA_BWD_DOT_DO_O_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_dvpad = BOOL_MAP[self.F_dvpad],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_occupancy = self.F_occupancy)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_dot_do_o_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_bwd_convert_dq_trait_{F_idx} =
|
||||
ck_tile::TileFmhaBwdConvertQGradTraits<{F_spad}, {F_dpad}, {F_occupancy}>;
|
||||
|
||||
using fmha_bwd_convert_dq_pipeline_problem_{F_idx} =
|
||||
ck_tile::BlockFmhaBwdConvertQGradPipelineProblem<
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::AccDataType,
|
||||
typename FmhaBwdTypeConfig<fmha_dtype_{F_idx}>::QGradDataType,
|
||||
/* BlockSize = */ 256,
|
||||
{F_bm0},
|
||||
{F_bn0},
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
{F_deterministic},
|
||||
fmha_bwd_convert_dq_trait_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_{F_idx} =
|
||||
typename ck_tile::BlockFmhaBwdConvertQGrad<fmha_bwd_convert_dq_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_bwd_convert_dq_kernel_{F_idx} =
|
||||
ck_tile::FmhaBwdConvertQGradKernel<fmha_bwd_convert_dq_{F_idx}>;
|
||||
|
||||
using convert_dq_trait_{F_idx} = fmha_bwd_convert_dq_traits_<{F_hdim},
|
||||
{F_dtype},
|
||||
{F_mode},
|
||||
{F_spad},
|
||||
{F_dpad},
|
||||
{F_deterministic},
|
||||
{F_bn0}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template <>
|
||||
float fmha_bwd_convert_dq_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s, fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
template <>
|
||||
void fmha_bwd_convert_dq_oneshot_<convert_dq_trait_{F_idx}>(const ck_tile::stream_config& s,
|
||||
fmha_bwd_args a)
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
auto [kargs, grids] = fmha_bwd_convert_dq_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(
|
||||
ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
|
||||
template <>
|
||||
std::string fmha_bwd_convert_dq_get_name_<convert_dq_trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = fmha_bwd_convert_dq_kernel_{F_idx};
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdConvertQGradKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_spad : str # true/false
|
||||
F_dpad : str #
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_occupancy : int #
|
||||
F_deterministic : str #
|
||||
disabled : bool # sometimes this kernel is not used
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
return FMHA_BWD_KERNEL_HEADER + \
|
||||
FMHA_BWD_CONVERT_DQ_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = BWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_bm0,
|
||||
F_bn0 = self.F_bn0,
|
||||
F_spad = BOOL_MAP[self.F_spad],
|
||||
F_dpad = BOOL_MAP[self.F_dpad],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_occupancy = self.F_occupancy,
|
||||
F_deterministic = BOOL_MAP[self.F_deterministic])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f"fmha_bwd_convert_dq_d{self.F_hdim}_{self.F_dtype}_b{self.F_bm0}x{self.F_bn0}_{self.F_mode}_o{self.F_occupancy}"
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
if self.F_deterministic == 't' : n += '_deterministic'
|
||||
else: n += '_ndeterministic'
|
||||
return n
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class FmhaBwdApiTrait:
|
||||
idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
# sync with fmha_bwd_traits<>, to generate fallback calls
|
||||
hdim : int
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
tile : FmhaBwdDQDKDVTileSize
|
||||
mask : str
|
||||
bias : str
|
||||
dbias : str
|
||||
dropout : str
|
||||
spad1d : str # spad for 1d kernels (dot/convert)
|
||||
dpad : Literal[0, 1, 8]
|
||||
dvpad : Literal[0, 1, 8]
|
||||
deterministic : str
|
||||
mask_impl : str
|
||||
tr_load : str
|
||||
|
||||
@property
|
||||
def bm0(self) -> int:
|
||||
return self.tile.F_bm0
|
||||
@property
|
||||
def bn0(self) -> int:
|
||||
return self.tile.F_bn0
|
||||
@property
|
||||
def bhdq(self) -> int:
|
||||
return self.tile.F_bhdq
|
||||
@property
|
||||
def bhdv(self) -> int:
|
||||
return self.tile.F_bhdv
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group':
|
||||
return 'true' # always support
|
||||
elif self.spad1d == 't':
|
||||
return f'a.seqlen_q % {M0_1D} != 0'
|
||||
else: # self.spad1d == 'f'
|
||||
return f'a.seqlen_q % {M0_1D} == 0'
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 0: return f'a.hdim_q % {self.bhdq} == 0'
|
||||
else: return f'a.hdim_q % {self.dpad} == 0'
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 0: return f'a.hdim_v % {self.bhdv} == 0'
|
||||
else: return f'a.hdim_v % {self.dvpad} == 0'
|
||||
|
||||
@property
|
||||
def extra_cond(self) -> str:
|
||||
if self.tr_load == 't' and self.tile.max_seq_q == 0 and self.tile.F_bn0 == 128:
|
||||
return "&& (a.seqlen_k <= 256)"
|
||||
else:
|
||||
return ""
|
||||
|
||||
@property
|
||||
def convert_dq_bn0(self) -> int:
|
||||
return self.tile.F_bn0 if self.deterministic == 't' else 0
|
||||
|
||||
@property
|
||||
def dot_do_o_kernel(self) -> FmhaBwdOGradDotOKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
F_dvpad = 't' if self.dvpad else 'f'
|
||||
return FmhaBwdOGradDotOKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_spad=self.spad1d,
|
||||
F_dvpad=F_dvpad, F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim))
|
||||
|
||||
@property
|
||||
def dq_dk_dv_kernel(self) -> FmhaBwdDQDKDVKernel:
|
||||
return FmhaBwdDQDKDVKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype, F_tile=self.tile,
|
||||
F_dpad=self.dpad, F_dvpad=self.dvpad, F_bias=self.bias, F_dbias=self.dbias, F_dropout=self.dropout,
|
||||
F_mask=self.mask, F_mode=self.mode, F_deterministic=self.deterministic, mask_impl=self.mask_impl, F_trload=self.tr_load)
|
||||
|
||||
@property
|
||||
def convert_dq_kernel(self) -> FmhaBwdConvertQGradKernel:
|
||||
# TODO: we don't support tuning yet, so pick up one value for pad/occupancy
|
||||
# support this in future
|
||||
def get_occupancy(dtype, hdim):
|
||||
return 2
|
||||
|
||||
F_dpad = 't' if self.dpad else 'f'
|
||||
return FmhaBwdConvertQGradKernel(F_idx=self.idx, F_hdim=self.hdim, F_dtype=self.dtype,
|
||||
F_bm0=M0_1D, F_bn0=self.convert_dq_bn0, F_spad=self.spad1d, F_dpad=F_dpad,
|
||||
F_mode=self.mode, F_occupancy=get_occupancy(self.dtype, self.hdim),
|
||||
F_deterministic=self.deterministic, disabled=self.tile.max_seq_q != 0)
|
||||
|
||||
class FmhaBwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.dq_dk_dv_pool = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
|
||||
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_dq_dk_dv_traits(self, trait : FmhaBwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
self.dq_dk_dv_pool[trait.tr_load][trait.tile.max_seq_q][trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@staticmethod
|
||||
def if_(i: int) -> str:
|
||||
return 'if' if i == 0 else 'else if'
|
||||
|
||||
def _api_innders(self, traits: List[FmhaBwdApiTrait]) -> str:
|
||||
inners = ""
|
||||
i = 0
|
||||
for trait in traits:
|
||||
inners += FMHA_BWD_API_INNER_DISPATCH.format(F_if=self.if_(i), F_mode=MODE_MAP[trait.mode],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_mask=get_mask_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias],
|
||||
F_bias=BIAS_MAP[trait.bias], F_dbias=BOOL_MAP[trait.dbias], F_dropout_check=DROPOUT_CHECK_MAP[trait.dropout], F_dropout=DROPOUT_MAP[trait.dropout],
|
||||
F_scheck=trait.scheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_hdim=trait.hdim, F_dtype=BWD_DTYPE_MAP[trait.dtype],
|
||||
F_spad1d=BOOL_MAP[trait.spad1d], F_dpad=trait.dpad, F_dvpad=trait.dvpad,
|
||||
F_deterministic=BOOL_MAP[trait.deterministic], F_trload=BOOL_MAP[trait.tr_load], F_maxq=trait.tile.max_seq_q,
|
||||
F_convert_dq_enabled=BOOL_MAP[not trait.convert_dq_kernel.disabled], F_bn0=trait.tile.F_bn0, F_cond_extra=trait.extra_cond,
|
||||
F_convert_dq_bn0=trait.convert_dq_bn0)
|
||||
i += 1
|
||||
return inners
|
||||
|
||||
@staticmethod
|
||||
def trload_sort_key(tf):
|
||||
return 0 if tf == 't' else 1 # sort 't' before 'f'
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_sort_key(max_seq_q):
|
||||
return max_seq_q if max_seq_q != 0 else 1000000 # sort 0 to the end
|
||||
|
||||
@staticmethod
|
||||
def max_seq_q_cond(max_seq_q: int) -> str:
|
||||
if max_seq_q == 0:
|
||||
return 'true /* no seqlen_q limit */'
|
||||
else:
|
||||
return f'a.seqlen_q <= {max_seq_q}'
|
||||
|
||||
@staticmethod
|
||||
def dtype_cond(dtype: str) -> str:
|
||||
return f't.data_type.compare("{dtype}") == 0'
|
||||
|
||||
@staticmethod
|
||||
def hdim_cond(hdim: int) -> str:
|
||||
return f't.hdim_q <= {hdim} && t.hdim_v <= {hdim}'
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {
|
||||
"t": "has_load_tr",
|
||||
"f": "true /* no trload requirement */"
|
||||
}
|
||||
per_tr_load = ''
|
||||
for tr_load in sorted(self.dq_dk_dv_pool.keys(), key=self.trload_sort_key):
|
||||
per_max_seq_q = ''
|
||||
for max_seq_q in sorted(self.dq_dk_dv_pool[tr_load].keys(), key=self.max_seq_q_sort_key):
|
||||
per_dtypes = ''
|
||||
for j, dtype in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q]):
|
||||
per_hdim_case = ''
|
||||
for k, hdim in enumerate(self.dq_dk_dv_pool[tr_load][max_seq_q][dtype]):
|
||||
traits = self.dq_dk_dv_pool[tr_load][max_seq_q][dtype][hdim]
|
||||
inners = self._api_innders(traits)
|
||||
per_hdim_case += FMHA_BWD_API_COND_STATEMENT(if_=k, F_cond=self.hdim_cond(hdim), F_body=inners)
|
||||
per_dtypes += FMHA_BWD_API_COND_STATEMENT(if_=j, F_cond=self.dtype_cond(dtype), F_body=per_hdim_case)
|
||||
per_max_seq_q += FMHA_BWD_API_COND_STATEMENT(F_cond=self.max_seq_q_cond(max_seq_q), F_body=per_dtypes)
|
||||
per_tr_load += FMHA_BWD_API_COND_STATEMENT(F_cond=tr_load_cond_map[tr_load], F_body=per_max_seq_q, indent=4)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += ' (void)t ; (void)s ; (void)a; (void)has_load_tr;'
|
||||
result = FMHA_BWD_KERNEL_HEADER + FMHA_BWD_API.format(F_dispatch = per_tr_load)
|
||||
return result.replace('\n\n', '\n')
|
||||
|
||||
def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[FmhaBwdApiPool, List[FmhaBwdOGradDotOKernel], List[FmhaBwdDQDKDVKernel], List[FmhaBwdConvertQGradKernel]]:
|
||||
if filter_list == '':
|
||||
filter_list = '*@*@*'
|
||||
filters = filter_list.split('@')
|
||||
filters.extend(['*'] * (3 - len(filters)))
|
||||
filter_dot_do_o = filters[0]
|
||||
filter_convert_dq = filters[1]
|
||||
filter_dq_dk_dv = filters[2]
|
||||
|
||||
# use dict as ordered set
|
||||
gen_dot_do_o: Dict[FmhaBwdOGradDotOKernel, Literal[True]] = {}
|
||||
gen_dq_dk_dv: Dict[FmhaBwdDQDKDVKernel, Literal[True]] = {}
|
||||
gen_convert_dq: Dict[FmhaBwdConvertQGradKernel, Literal[True]] = {}
|
||||
api_pool = FmhaBwdApiPool(mask_impl)
|
||||
|
||||
for dtype, tr_load in itertools.product(BWD_DTYPE_MAP.keys(), ["t", "f"]):
|
||||
tiles: Any = get_dq_dk_dv_tiles(dtype, tr_load)
|
||||
dpad_options = itertools.product(*([[0, 8, 1]] * 2))
|
||||
tf = ["t", "f"]
|
||||
for tile, mode, mask, bias, dbias, dropout, spad1d, (dpad, dvpad), deterministic in itertools.product(
|
||||
tiles, MODE_MAP.keys(), get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), tf, DROPOUT_MAP.keys(), tf, dpad_options, tf):
|
||||
assert isinstance(tile, FmhaBwdDQDKDVTileSize), "tile must be FmhaBwdDQDKDVTileSize"
|
||||
hdim = tile.F_bhdq
|
||||
if (mode == "group") and (spad1d == "f"):
|
||||
continue
|
||||
if (mode == "group" or ('no' not in mask)) and tile.max_seq_q != 0:
|
||||
continue
|
||||
if ((bias == "no" or bias == "alibi") and dbias == "t"):
|
||||
continue
|
||||
if ("wg32" in dropout):
|
||||
continue
|
||||
if tr_load == "t":
|
||||
continue # tr_load cannot work with dpad or dvpad
|
||||
else: # tr_load == "f"
|
||||
# do not generate instance with only 1 of dpad/dvpad being 8
|
||||
if dpad != dvpad and dpad == 8:
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
t = FmhaBwdApiTrait(idx=0, hdim=hdim, dtype=dtype, mode=mode,tile=tile,mask=mask, bias=bias, dbias=dbias, dropout=dropout, spad1d=spad1d, dpad=dpad, dvpad=dvpad, deterministic=deterministic, mask_impl=mask_impl, tr_load=tr_load)
|
||||
|
||||
if not fnmatch.fnmatch(t.dot_do_o_kernel.name, filter_dot_do_o):
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.dq_dk_dv_kernel.name, filter_dq_dk_dv):
|
||||
continue
|
||||
if not fnmatch.fnmatch(t.convert_dq_kernel.name, filter_convert_dq):
|
||||
continue
|
||||
|
||||
# Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 3:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'alibi']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= bias in ['no', 'bias']
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
cond &= dpad == dvpad
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_bwd) integration
|
||||
elif receipt == 300:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "batch"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter (mha_varlen_bwd) integration
|
||||
elif receipt == 400:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= dropout in ['no', 'dropout_wg32', 'dropout_wg16']
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_bwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only, all variations
|
||||
if receipt == 800:
|
||||
cond = dtype == 'fp32'
|
||||
cond &= dpad == dvpad
|
||||
if not cond:
|
||||
continue
|
||||
# fp32 only, minimal set of parameters
|
||||
elif receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
cond &= hdim in [64, 128]
|
||||
cond &= dpad == dvpad
|
||||
cond &= mode == 'batch'
|
||||
cond &= bias == 'no'
|
||||
cond &= dropout == 'no'
|
||||
cond &= mask == 's_no'
|
||||
cond &= deterministic == "f"
|
||||
if not cond:
|
||||
continue
|
||||
else:
|
||||
# Don't build fp32 by default
|
||||
if dtype == 'fp32':
|
||||
continue
|
||||
|
||||
gen_dot_do_o[t.dot_do_o_kernel] = True
|
||||
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
|
||||
if not t.convert_dq_kernel.disabled:
|
||||
gen_convert_dq[t.convert_dq_kernel] = True
|
||||
api_pool.register_dq_dk_dv_traits(t)
|
||||
|
||||
return api_pool, list(gen_dot_do_o.keys()), list(gen_dq_dk_dv.keys()), list(gen_convert_dq.keys())
|
||||
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(filter_list, receipt, mask_impl, optdim_list)
|
||||
update_file(output_dir / FMHA_BWD_API_FILENAME, api_pool.api)
|
||||
for k in kernels_dot_do_o:
|
||||
update_file(output_dir / k.filename, k.template)
|
||||
for k in kernels_convert_dq:
|
||||
update_file(output_dir / k.filename, k.template)
|
||||
for k in kernels_dq_dk_dv:
|
||||
update_file(output_dir / k.filename, k.template)
|
||||
|
||||
|
||||
def list_blobs(file_path: Path, filter_list: str, receipt, optdim_list, mask_impl) -> None:
|
||||
_, kernels_dot_do_o, kernels_dq_dk_dv, kernels_convert_dq = get_bwd_blobs(
|
||||
filter_list, receipt, mask_impl, optdim_list
|
||||
)
|
||||
with file_path.open("a") as f:
|
||||
for k in kernels_dot_do_o:
|
||||
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
|
||||
for k in kernels_dq_dk_dv:
|
||||
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
|
||||
for k in kernels_convert_dq:
|
||||
f.write(str(file_path.parent / GEN_DIR / k.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_BWD_API_FILENAME) + "\n")
|
||||
@@ -1,783 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass, field
|
||||
import fnmatch
|
||||
import itertools
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
from codegen.utils import update_file
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
48 : 48,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
192: 192,
|
||||
256: 256
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse},
|
||||
{F_dropout},
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::RandValOutputDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
#include <cstdio>
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
namespace {{
|
||||
bool get_num_cus(unsigned& num_cus) {{
|
||||
int device;
|
||||
auto status = hipGetDevice(&device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device");
|
||||
return false;
|
||||
}}
|
||||
|
||||
hipDeviceProp_t props{{}};
|
||||
status = hipGetDeviceProperties(&props, device);
|
||||
if(status != hipSuccess) {{
|
||||
fprintf(stderr, "failed to get device properties");
|
||||
return false;
|
||||
}}
|
||||
|
||||
num_cus = props.multiProcessorCount;
|
||||
return true;
|
||||
}}
|
||||
|
||||
unsigned get_num_thread_blocks(unsigned batch, unsigned nheads, unsigned max_seqlen_q, unsigned kM0) {{
|
||||
const unsigned num_m_blocks = (max_seqlen_q + kM0 - 1) / kM0;
|
||||
const unsigned num_n_blocks = 1; // we assume that num_n_blocks is always 1
|
||||
|
||||
return batch * nheads * num_m_blocks * num_n_blocks;
|
||||
}}
|
||||
}} // namespace
|
||||
|
||||
float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
|
||||
[[maybe_unused]] const float min_cu_util_rate = 0.8; // minimum CU utilization rate
|
||||
|
||||
unsigned num_cus;
|
||||
if (!get_num_cus(num_cus)) {{
|
||||
return r;
|
||||
}}
|
||||
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class CppConstraint:
|
||||
bool_expr: str = None
|
||||
|
||||
def __str__(self):
|
||||
if self.bool_expr is None:
|
||||
return 'true'
|
||||
else:
|
||||
return f'{self.bool_expr}'
|
||||
|
||||
def __and__(self, other):
|
||||
return CppConstraint(f'({str(self)}) && ({str(other)})')
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
dropout : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
tr_load : str
|
||||
constraint : CppConstraint
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.dropout}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag in ['qr_async', 'qr_async_trload']:
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
def seqtune(self, max_bm0 : int) -> str:
|
||||
if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/'
|
||||
else:
|
||||
return f'a.seqlen_q <= {self.bm0}'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'(a.cu_seqlen_kv_ptr != nullptr) || (a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0)'
|
||||
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'(a.cu_seqlen_kv_ptr == nullptr) && (a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0)'
|
||||
elif self.pipeline_tag == 'qr_async_trload':
|
||||
if self.skpad == 't' : return 'true'
|
||||
else: return 'true'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_dropout : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
F_trload : str # true/false
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_dropout == 't' : n += '_dropout'
|
||||
else: n += '_ndropout'
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
if self.F_trload == 't' : n += '_trload'
|
||||
else: n += '_ntrload'
|
||||
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
hdim = trait.hdim, trait.bn1
|
||||
if hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
tr_load_cond_map = {
|
||||
"t": "has_load_tr",
|
||||
"f": "true"
|
||||
}
|
||||
|
||||
per_tr_load =str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
|
||||
max_bm0 = max((t.bm0 for t in traits), default=0)
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_tr_load += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_dropout = BOOL_MAP[self.F_pipeline.F_dropout],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload = BOOL_MAP[self.F_pipeline.F_trload])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
dropout=self.F_pipeline.F_dropout,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
|
||||
class KernelComponentFactory:
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp32':
|
||||
return {
|
||||
# bm0, bn0, bk0, bn1, bk1,
|
||||
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
|
||||
}
|
||||
elif dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'fp8bf16':
|
||||
return {
|
||||
(64,64 ) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
(256,256) : [FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
}
|
||||
elif dtype == 'fp8fp32':
|
||||
return {
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1)],
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
@staticmethod
|
||||
def get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
pipelines = []
|
||||
if dtype in ['fp32']:
|
||||
squant = 'f'
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
elif dtype in ['fp16', 'bf16']:
|
||||
squant = 'f'
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f":
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'fp8bf16', 'fp8fp32']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, squant, mask, bias in itertools.product(["f"], ["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
elif dtype in ['fp8fp16', 'bf8']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
class CustomFactory(KernelComponentFactory):
|
||||
@staticmethod
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
result = KernelComponentFactory.get_hdim_tile_size_dict(dtype)
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
if (128, 128) in result.keys():
|
||||
result[(128, 128)].insert(0, FmhaFwdTileSize( 64, 128, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1, CppConstraint('get_num_blocks(128) < num_cus * min_cu_util_rate')))
|
||||
return result
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
factory = CustomFactory if os.environ.get('CK_TILE_FMHA_FWD_CUSTOM_FACTORY', '0') == '1' else KernelComponentFactory
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = factory.get_hdim_tile_size_dict(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
|
||||
for tile, next_tile in zip(tiles, tiles[1:]):
|
||||
assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0'
|
||||
for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if (hdim, hdim_v) == (192, 128):
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
if dtype != 'fp32':
|
||||
if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)):
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
continue
|
||||
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
cond &= pipeline.F_logits == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16', 'fp8bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if dtype == 'fp8bf16':
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
elif receipt == 888:
|
||||
cond = dtype in ['fp8', 'fp8bf16', 'fp8fp32']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= hdim == 128
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only, all variations
|
||||
if receipt == 800:
|
||||
cond = dtype == 'fp32'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
cond &= pipeline.F_logits == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# fp32 only, minimal set of parameters
|
||||
elif receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
cond &= hdim in [48, 128]
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_bias == 'no'
|
||||
cond &= pipeline.F_lse == 'f'
|
||||
cond &= pipeline.F_dropout == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
cond &= pipeline.F_logits == 'f'
|
||||
cond &= pipeline.F_mask == 's_no'
|
||||
if not cond:
|
||||
continue
|
||||
else:
|
||||
# Don't build fp32 by default
|
||||
if dtype == 'fp32':
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
@@ -1,376 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
from codegen.ops.fmha_fwd import (
|
||||
FmhaFwdApiTrait,
|
||||
DTYPE_BITS,
|
||||
FMHA_FWD_KERNEL_HEADER,
|
||||
FMHA_FWD_API_PER_DTYPE,
|
||||
FMHA_FWD_API_PER_HDIM_CASE,
|
||||
)
|
||||
|
||||
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdAppendKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
{F_bs},
|
||||
{F_bsk},
|
||||
{F_bd},
|
||||
{F_bdv},
|
||||
{F_vlayout},
|
||||
{F_rope},
|
||||
{F_pagedkv},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = ck_tile::BlockFmhaFwdAppendKVPipeline<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_kernel_{F_idx} = ck_tile::FmhaFwdAppendKVKernel<fmha_pipeline_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout},
|
||||
{F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_appendkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_appendkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_appendkv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_FILENAME="fmha_fwd_appendkv_api.cpp"
|
||||
FMHA_FWD_APPENDKV_API="""
|
||||
float fmha_fwd_appendkv(fmha_fwd_appendkv_traits t, fmha_fwd_appendkv_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_APPENDKV_API_INNER_DISPATCH=""" {F_if}((t.is_v_rowmajor == {F_vlayout}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && (t.rope_type == {F_rope_check}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv})) {{
|
||||
using trait_ = fmha_fwd_appendkv_traits_<{F_hdim}, {F_dtype}, {F_bs}, {F_bsk}, {F_bd}, {F_bdv}, {F_vlayout}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_rope}, {F_pagedkv}>;
|
||||
return fmha_fwd_appendkv_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVApiTrait:
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
bs : int # tile size along q seqlen
|
||||
bsk : int # tile size along k seqlen
|
||||
bd : int # tile size along qk gemm unroll
|
||||
bdv : int # tile size along kv gemm unroll
|
||||
vlayout : str
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
rope : str # key from ROPE_MAP
|
||||
pagedkv : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.bs}-{self.bsk}-{self.bd}-{self.bdv}-{self.vlayout}-'+\
|
||||
f'{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.rope}-{self.pagedkv}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bs} != 0*/'
|
||||
else : return f'a.seqlen_q % {self.bs} == 0'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
# we do not check all the values in a.seqlen_k_ptr
|
||||
return 'true'
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {self.bd} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {self.bd} == 0'
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {self.bdv} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {self.bdv} == 0'
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVPipeline:
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_rope : str # key from ROPE_MAP
|
||||
F_pagedkv : str # t/f
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
if self.F_rope != 'no': n += f'_{self.F_rope}'
|
||||
if self.F_pagedkv == 't': n += '_pagedkv'
|
||||
return n
|
||||
|
||||
class FmhaFwdAppendKVApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_APPENDKV_API_INNER_DISPATCH.format(F_if=if_k, F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck, F_rope_check=ROPE_CHECK_MAP[trait.rope],
|
||||
F_pagedkv=BOOL_MAP[trait.pagedkv], F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_rope=ROPE_MAP[trait.rope], F_bs=trait.bs, F_bsk=trait.bsk, F_bd=trait.bd, F_bdv=trait.bdv, F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVTileSize:
|
||||
F_bs : int # tile size along q seqlen
|
||||
F_bsk : int # tile size along k seqlen
|
||||
F_bd : int # tile size along qk gemm unroll
|
||||
F_bdv : int # tile size along kv gemm unroll
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bs}x{self.F_bsk}x{self.F_bd}x{self.F_bdv}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdAppendKVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_tile : FmhaFwdAppendKVTileSize
|
||||
F_pipeline : FmhaFwdAppendKVPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_APPENDKV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bs = self.F_tile.F_bs,
|
||||
F_bsk = self.F_tile.F_bsk,
|
||||
F_bd = self.F_tile.F_bd,
|
||||
F_bdv = self.F_tile.F_bdv,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_rope = ROPE_MAP[self.F_pipeline.F_rope],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy = self.F_tile.F_occupancy)
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_appendkv_d{self.F_hdim}_{self.F_dtype}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdAppendKVApiTrait:
|
||||
return FmhaFwdAppendKVApiTrait(
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
bs=self.F_tile.F_bs,
|
||||
bsk=self.F_tile.F_bsk,
|
||||
bd=self.F_tile.F_bd,
|
||||
bdv=self.F_tile.F_bdv,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
rope=self.F_pipeline.F_rope,
|
||||
pagedkv=self.F_pipeline.F_pagedkv)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdAppendKVTileSize(64, 64, 32, 32, -1),
|
||||
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdAppendKVTileSize(64, 64, 64, 64, -1),
|
||||
'128' : FmhaFwdAppendKVTileSize(64, 64, 128, 128, -1),
|
||||
'256' : FmhaFwdAppendKVTileSize(64, 64, 256, 256, -1)
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdAppendKVApiPool, List[FmhaFwdAppendKVKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdAppendKVPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
# NOTICE: it will be very complicated if we consider all the hdim_q padding cases while
|
||||
# applying rotary embedding, so I just use 't' in inter/half pipelines
|
||||
for vlayout in ['row', 'col']:
|
||||
for pagedkv in ["t", "f"]:
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 'f', 'f', 'no', pagedkv))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'no', pagedkv))
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'inter', pagedkv))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'inter', pagedkv))
|
||||
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 'f', 't', 't', 'f', 'half', pagedkv))
|
||||
pipelines.append(FmhaFwdAppendKVPipeline(vlayout, 't', 't', 't', 't', 'half', pagedkv))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# rope/paged-kv is not supported
|
||||
pipelines.append(FmhaFwdAppendKVPipeline('col', 't', 't', 't', 't', 'no', 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdAppendKVApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_appendkv_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
for hdim_str in d.keys():
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
k = FmhaFwdAppendKVKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_kernel(kernel: FmhaFwdAppendKVKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_appendkv_api(api_pool : FmhaFwdAppendKVApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_APPENDKV_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_appendkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_appendkv_blobs(kernel_filter, receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_APPENDKV_API_FILENAME) + "\n")
|
||||
@@ -1,885 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
from codegen.ops.fmha_fwd import (
|
||||
FmhaFwdTileSize,
|
||||
FmhaFwdApiTrait,
|
||||
FMHA_FWD_KERNEL_HEADER,
|
||||
FMHA_FWD_API_PER_DTYPE,
|
||||
FMHA_FWD_API_PER_HDIM_CASE,
|
||||
)
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
# 160: 160,
|
||||
256: 256
|
||||
}
|
||||
|
||||
FMHA_FWD_SPLITKV_PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaFwdSplitKVPipelineQRKSVS",
|
||||
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaFwdSplitKVPipelineNWarpSShuffleQRKSVS",
|
||||
}
|
||||
|
||||
FMHA_FWD_SPLITKV_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
namespace {{
|
||||
template <bool kHasUnevenSplits, bool kMergeNumHeadGroupsSeqLenQ = false>
|
||||
struct instance {{
|
||||
using fmha_block_tile = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape = ck_tile::TileFmhaShape<fmha_block_tile,
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait = ck_tile::TileFmhaFwdSplitKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
/*kHasBiasGrad=*/false,
|
||||
{F_lse},
|
||||
{F_squant},
|
||||
{F_pagedkv},
|
||||
kHasUnevenSplits,
|
||||
kMergeNumHeadGroupsSeqLenQ,
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem = ck_tile::BlockFmhaFwdSplitKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
fmha_shape,
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait>;
|
||||
|
||||
using fmha_pipeline = {F_pipeline}<
|
||||
fmha_pipeline_problem>;
|
||||
|
||||
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
|
||||
/// store_tile_raw() data corruption issue
|
||||
using fmha_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
false, false>>;
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdSplitKVKernel<fmha_pipeline, fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
}};
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad},
|
||||
{F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wtautological-compare"
|
||||
|
||||
namespace {{
|
||||
template <bool kHasUnevenSplits>
|
||||
void run_instance(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a) {{
|
||||
if constexpr ({F_hdim} == 128 && {F_bias} == ck_tile::BlockAttentionBiasEnum::NO_BIAS
|
||||
&& (std::is_same_v<{F_mask}, ck_tile::SimplifiedGenericAttentionMask<false>>
|
||||
|| std::is_same_v<{F_mask}, FmhaMasks::NoMask>)) {{
|
||||
if (a.max_seqlen_q == 1 && a.nhead_k < a.nhead_q) {{
|
||||
instance<kHasUnevenSplits, /*kMergeNumHeadGroupsSeqLenQ=*/true>::run(s, a);
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
instance<kHasUnevenSplits>::run(s, a);
|
||||
}}
|
||||
}}
|
||||
}} // anonymous namespace
|
||||
|
||||
#pragma clang diagnostic pop
|
||||
|
||||
template<>
|
||||
void fmha_fwd_splitkv_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if constexpr({F_mode} == false) {{ // batch mode
|
||||
// we don't check every seqlen_k values for kvcache
|
||||
if (a.seqlen_k_ptr != nullptr) {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
// make sure F_bn0 is divisible by F_bk1
|
||||
}} else if (a.seqlen_k % (a.num_splits * {F_bn0}) == 0) {{
|
||||
run_instance</*kHasUnevenSplits=*/false>(s, a);
|
||||
}} else {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
run_instance</*kHasUnevenSplits=*/true>(s, a);
|
||||
}}
|
||||
}}
|
||||
|
||||
template<>
|
||||
std::string fmha_fwd_splitkv_get_name_<trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = instance<true>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
namespace {{
|
||||
template <ck_tile::index_t kLogMaxSplits>
|
||||
struct instance {{
|
||||
using fmha_trait = ck_tile::TileFmhaFwdSplitKVCombineTraits<{F_spad},
|
||||
{F_dvpad},
|
||||
{F_lse},
|
||||
{F_squant},
|
||||
kLogMaxSplits,
|
||||
{F_occupancy}>;
|
||||
|
||||
using fmha_pipeline_problem = ck_tile::BlockFmhaSplitKVCombinePipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
{F_hdim},
|
||||
{F_mode},
|
||||
{F_bn1},
|
||||
fmha_trait>;
|
||||
|
||||
using fmha_pipeline = ck_tile::BlockFmhaFwdSplitKVCombinePipeline<
|
||||
fmha_pipeline_problem>;
|
||||
|
||||
/// FIXME: use {F_spad}/{F_dvpad} as kPadM/kPadN parameters after solving
|
||||
/// store_tile_raw() data corruption issue
|
||||
using fmha_epilogue =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
false, false>>;
|
||||
|
||||
using fmha_kernel =
|
||||
ck_tile::FmhaFwdSplitKVCombineKernel<fmha_pipeline, fmha_epilogue>;
|
||||
|
||||
static void run(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel;
|
||||
auto [kargs, grids] = fmha_fwd_splitkv_combine_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs)(ck_tile::stream_config{{s.stream_id_}});
|
||||
}}
|
||||
}};
|
||||
}}
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bn1},
|
||||
{F_lse}, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
void fmha_fwd_splitkv_combine_oneshot_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if (a.num_splits <= 8) {{
|
||||
instance<3>::run(s, a);
|
||||
}} else if (a.num_splits <= 16) {{
|
||||
instance<4>::run(s, a);
|
||||
}} else if (a.num_splits <= 32) {{
|
||||
instance<5>::run(s, a);
|
||||
}} else if (a.num_splits <= 64) {{
|
||||
instance<6>::run(s, a);
|
||||
}} else if (a.num_splits <= 128) {{
|
||||
instance<7>::run(s, a);
|
||||
}}
|
||||
}}
|
||||
|
||||
template<>
|
||||
std::string fmha_fwd_splitkv_combine_get_name_<trait_{F_idx}>()
|
||||
{{
|
||||
using k_ = instance<6>::fmha_kernel; /// FIXME: choose real kernel type
|
||||
return k_::GetName();
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_FILENAME="fmha_fwd_splitkv_api.cpp"
|
||||
FMHA_FWD_SPLITKV_API="""
|
||||
#include <iostream>
|
||||
|
||||
template<typename fmha_fwd_splitkv_traits_, typename fmha_fwd_splitkv_combine_traits_>
|
||||
float fmha_fwd_splitkv_(const ck_tile::stream_config& s, fmha_fwd_splitkv_args a)
|
||||
{{
|
||||
if(s.log_level_ > 0)
|
||||
std::cout
|
||||
<< ", " << fmha_fwd_splitkv_get_name_<fmha_fwd_splitkv_traits_>()
|
||||
<< ", " << fmha_fwd_splitkv_combine_get_name_<fmha_fwd_splitkv_combine_traits_>()
|
||||
<< std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(s,
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_oneshot_<fmha_fwd_splitkv_traits_>(s_, a); }},
|
||||
[=](const ck_tile::stream_config& s_){{ fmha_fwd_splitkv_combine_oneshot_<fmha_fwd_splitkv_combine_traits_>(s_, a); }}
|
||||
);
|
||||
}}
|
||||
|
||||
float fmha_fwd_splitkv(fmha_fwd_splitkv_traits t, fmha_fwd_splitkv_args a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_SPLITKV_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_splitkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, true, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
// get combine kernel tile sizes
|
||||
using OaccDataType = typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType;
|
||||
constexpr ck_tile::index_t kM0 = ck_tile::BlockFmhaSplitKVCombinePipelineTileSizes<OaccDataType, /*F_bn1=*/32>::kM0;
|
||||
|
||||
// make sure we can reuse the padding flags in combine kernels
|
||||
static_assert({F_bm0} % kM0 == 0);
|
||||
static_assert({F_bn1} % 32 == 0);
|
||||
|
||||
if (t.has_lse) {{
|
||||
if constexpr (std::is_same_v<{F_dtype}, FmhaFwdFp8>) {{
|
||||
return -1;
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, true, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
}} else {{
|
||||
using traits2_ = fmha_fwd_splitkv_combine_traits_<{F_hdim}, {F_dtype}, {F_mode}, /*F_bn1=*/32, false, {F_squant}, {F_spad}, {F_dvpad}>;
|
||||
|
||||
return fmha_fwd_splitkv_<traits_, traits2_>(s, a);
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
mask : str
|
||||
logits : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
pagedkv : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
|
||||
f'{self.dvpad}-{self.pagedkv}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qr_nwarp_sshuffle']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_squant : str #
|
||||
F_pagedkv : str # t/f
|
||||
F_mask : str # value from MASK_MAP
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
if self.F_pagedkv == 't' : n += '_pagedkv'
|
||||
else: n += '_npagedkv'
|
||||
return n
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVCombinePipeline:
|
||||
tag : str
|
||||
|
||||
F_spad : str # true/false
|
||||
F_dvpad : str #
|
||||
F_lse : str #
|
||||
F_squant : str #
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
return n
|
||||
|
||||
class FmhaFwdSplitKVApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdSplitKVApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_SPLITKV_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_squant=BOOL_MAP[trait.squant], F_pagedkv=BOOL_MAP[trait.pagedkv],
|
||||
F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_SPLITKV_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVCombineTileSize:
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bn1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdSplitKVPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_SPLITKV_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_FWD_SPLITKV_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_splitkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdSplitKVApiTrait:
|
||||
return FmhaFwdSplitKVApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdSplitKVCombineKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdSplitKVCombineTileSize
|
||||
F_pipeline : FmhaFwdSplitKVCombinePipeline
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_SPLITKV_COMBINE_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_mode = MODE_MAP[self.F_mode])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_splitkv_combine_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'64' : FmhaFwdTileSize(64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(64, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
# '160' : FmhaFwdTileSize(64, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
'32' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'96' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
# '160' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'128' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
'256' : FmhaFwdSplitKVCombineTileSize(32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, optdim_list) -> Tuple[FmhaFwdSplitKVApiPool, List[FmhaFwdSplitKVKernel]]:
|
||||
Pipeline = FmhaFwdSplitKVPipeline
|
||||
Kernel = FmhaFwdSplitKVKernel
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
|
||||
pipelines.append(Pipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 'f', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, 't', squant, pagedkv, mask))
|
||||
|
||||
pipelines.append(Pipeline('qr', 'row', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('qr', 'col', 't', 't', 't', 't', logits, bias, 't', squant, pagedkv, mask))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(Pipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 't', squant, 'f', mask))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdSplitKVApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# Flash attention integration
|
||||
if receipt == 2:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16, bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= mode == 'batch'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd_splikv C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim_list) -> List[FmhaFwdSplitKVCombineKernel]:
|
||||
Pipeline = FmhaFwdSplitKVCombinePipeline
|
||||
Kernel = FmhaFwdSplitKVCombineKernel
|
||||
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdSplitKVCombinePipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(Pipeline('unused', spad, dvpad, lse, squant))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse kernels
|
||||
pipelines.append(Pipeline('unused', 'f', 'f', 'f', squant))
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_splitkv_combine_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
k = Kernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
if receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == "group"
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd_splikv C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
gen.append(k)
|
||||
|
||||
return gen
|
||||
|
||||
def write_single_kernel(kernel: Union[FmhaFwdSplitKVKernel, FmhaFwdSplitKVCombineKernel], autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_splitkv_api(api_pool : FmhaFwdSplitKVApiPool, autogen_dir: Path) -> None:
|
||||
file_path = autogen_dir / FMHA_FWD_SPLITKV_API_FILENAME
|
||||
file_path.write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
api_pool, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
write_single_kernel(kernel, output_dir)
|
||||
write_fwd_splitkv_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, filter_list : str, receipt, optdim_list, mask_impl) -> None:
|
||||
filter_list = filter_list.split('@')
|
||||
filter_list.extend([''] * (2 - len(filter_list)))
|
||||
|
||||
with file_path.open('a') as f:
|
||||
kernels = get_fwd_splitkv_combine_blobs(filter_list[0], receipt, optdim_list)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
_, kernels = get_fwd_splitkv_blobs(filter_list[1], receipt, mask_impl, optdim_list)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_SPLITKV_API_FILENAME) + "\n")
|
||||
@@ -1,591 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
import fnmatch
|
||||
import itertools
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
"fp32": 32,
|
||||
"fp16": 16,
|
||||
"bf16": 16,
|
||||
"fp8" : 8,
|
||||
"bf8" : 8
|
||||
}
|
||||
|
||||
K0_MAX_SUBMAX_MAP = {
|
||||
32 : 32,
|
||||
64 : 64,
|
||||
96 : 128,
|
||||
128: 128,
|
||||
256: 256
|
||||
}
|
||||
|
||||
FMHA_FWD_PAGEDKV_PIPELINE_MAP = {
|
||||
"qr_pagedkv" : "ck_tile::BlockFmhaFwdPagedKVPipelineQRKSVS"
|
||||
}
|
||||
|
||||
FMHA_FWD_KERNEL_HEADER = """// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n
|
||||
// auto generated by generate.py
|
||||
#include "ck_tile/ops/fmha/block/variants.hpp"
|
||||
#include "fmha_fwd.hpp"
|
||||
"""
|
||||
|
||||
FMHA_FWD_KERNEL_BODY="""
|
||||
using fmha_dtype_{F_idx} = {F_dtype};
|
||||
|
||||
using fmha_block_tile_{F_idx} = ck_tile::sequence<{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}>;
|
||||
|
||||
using fmha_shape_{F_idx} = ck_tile::TileFmhaShape<fmha_block_tile_{F_idx},
|
||||
ck_tile::sequence<{F_rm0}, {F_rn0}, {F_rk0}>,
|
||||
ck_tile::sequence<{F_wm0}, {F_wn0}, {F_wk0}>,
|
||||
ck_tile::sequence<{F_rm1}, {F_rn1}, {F_rk1}>,
|
||||
ck_tile::sequence<{F_wm1}, {F_wn1}, {F_wk1}>,
|
||||
{F_vlayout}>;
|
||||
|
||||
using fmha_trait_{F_idx} = ck_tile::TileFmhaFwdPagedKVTraits<{F_spad},
|
||||
{F_skpad},
|
||||
{F_dpad},
|
||||
{F_dvpad},
|
||||
{F_logits},
|
||||
{F_bias},
|
||||
false,
|
||||
{F_lse}, //lse
|
||||
{F_pagedkv}, //pagedkv
|
||||
{F_squant},
|
||||
{F_occupancy},
|
||||
{F_skip}>;
|
||||
|
||||
using fmha_variant_{F_idx} = ck_tile::ComposedAttention<{F_logits} * ck_tile::LOGITS_SOFT_CAP, CK_TILE_FMHA_FWD_FAST_EXP2>;
|
||||
|
||||
using fmha_mask_{F_idx} = {F_mask};
|
||||
|
||||
using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaFwdPagedKVPipelineProblem<
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::QDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::KDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::VDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::SMPLComputeDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::BiasDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::LSEDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::PDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<fmha_dtype_{F_idx}>::ODataType,
|
||||
fmha_shape_{F_idx},
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
fmha_pipeline_problem_{F_idx}>;
|
||||
|
||||
using fmha_epilogue_{F_idx} =
|
||||
ck_tile::Default2DEpilogue<ck_tile::Default2DEpilogueProblem<typename FmhaFwdTypeConfig<{F_dtype}>::OaccDataType,
|
||||
typename FmhaFwdTypeConfig<{F_dtype}>::ODataType,
|
||||
{F_spad}, {F_dvpad}>>;
|
||||
|
||||
using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdPagedKVKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
template<>
|
||||
float fmha_fwd_pagedkv_<trait_{F_idx}>(const ck_tile::stream_config& s, fmha_fwd_pagedkv_args a)
|
||||
{{
|
||||
using k_ = fmha_kernel_{F_idx};
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << k_::GetName() << std::flush;
|
||||
auto [kargs, grids] = fmha_fwd_pagedkv_create_kargs_and_grids<k_>(a);
|
||||
const dim3 blocks = k_::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = k_::kBlockPerCu;
|
||||
return ck_tile::launch_kernel(s, ck_tile::make_kernel<kBlockPerCu>(k_{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_FILENAME="fmha_fwd_pagedkv_api.cpp"
|
||||
FMHA_FWD_API="""
|
||||
float fmha_fwd_pagedkv(fmha_fwd_pagedkv_traits& t, fmha_fwd_pagedkv_args& a, const ck_tile::stream_config& s){{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
"""
|
||||
FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <= {F_hdim_v}) {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.use_pagedkv == {F_pagedkv}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
using trait_ = fmha_fwd_pagedkv_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_pagedkv}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
return fmha_fwd_pagedkv_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdApiTrait:
|
||||
pipeline_tag : str
|
||||
# sync with fmha_fwd_traits<>, to generate fallback calls
|
||||
hdim : str
|
||||
dtype : str # data type
|
||||
mode : str # value from MODE_MAP
|
||||
bm0 : int # tile size along q seqlen (block size)
|
||||
bn0 : int # tile size along qk seqlen
|
||||
bk0 : int # tile size along qk gemm unroll
|
||||
bn1 : int # tile size along v head_dim
|
||||
bk1 : int # tile size along kv gemm unroll
|
||||
bk0max : int
|
||||
vlayout : str
|
||||
logits : str
|
||||
mask : str
|
||||
bias : str #
|
||||
lse : str #
|
||||
pagedkv : str
|
||||
squant : str #
|
||||
spad : str
|
||||
skpad : str
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f'{self.hdim}-{self.dtype}-{self.mode}-{self.bm0}-{self.bn0}-{self.bk0}-{self.bn0}-{self.bk1}-{self.bk0max}-'+\
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.pagedkv}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-{self.dvpad}-{self.skip}'
|
||||
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode skpad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.skpad == 't' : return f'a.seqlen_k == 0 || a.seqlen_k % {self.bn0} != 0'
|
||||
else : return f'a.seqlen_k != 0 && a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k_ptr != nullptr || a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k_ptr == nullptr && a.seqlen_k % {self.bn0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def dvcheck(self) -> str:
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr_pagedkv', 'qs']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdPipeline:
|
||||
tag : str
|
||||
|
||||
F_vlayout : str # row/col
|
||||
F_spad : str # true/false
|
||||
F_skpad : str #
|
||||
F_dpad : str #
|
||||
F_dvpad : str #
|
||||
F_logits : str # t/f
|
||||
F_bias : str # true/false
|
||||
F_lse : str #
|
||||
F_pagedkv : str #
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
def pad_name() -> str:
|
||||
n = ''
|
||||
if self.F_spad == 't': n += 's'
|
||||
if self.F_skpad == 't' : n += 'sk'
|
||||
if self.F_dpad == 't' : n += 'd'
|
||||
if self.F_dvpad == 't' : n += 'dv'
|
||||
if n != '' : n = 'p' + n
|
||||
return n
|
||||
pn = pad_name()
|
||||
n = f'{self.tag}_v{self.F_vlayout[0]}'
|
||||
if pn != '' : n += f'_{pn}'
|
||||
else: n += '_npad'
|
||||
|
||||
if self.F_logits == 't' : n += '_logits'
|
||||
else: n += '_nlogits'
|
||||
|
||||
if self.F_bias != 'no' : n += f'_{self.F_bias}'
|
||||
else: n += '_nbias'
|
||||
|
||||
if self.F_mask[0:2] == 's_':
|
||||
if self.F_mask == 's_mask': n += f'_mask'
|
||||
else: n += '_nmask'
|
||||
else:
|
||||
if self.F_mask != 'no' : n += f'_m{self.F_mask[0]}'
|
||||
else: n += '_nmask'
|
||||
|
||||
if self.F_lse == 't' : n += '_lse'
|
||||
else: n += '_nlse'
|
||||
|
||||
if self.F_skip == 't' : n += '_skip'
|
||||
else: n += '_nskip'
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
if self.F_pagedkv == 't' : n += '_pagedkv'
|
||||
else: n += '_npagedkv'
|
||||
|
||||
return n
|
||||
|
||||
class FmhaFwdApiPool:
|
||||
def __init__(self, mask_impl):
|
||||
self.pool = dict()
|
||||
self.mask_impl = mask_impl
|
||||
|
||||
def register_traits(self, trait : FmhaFwdApiTrait) -> None:
|
||||
# TODO: do we need to check duplication?
|
||||
if trait.dtype not in self.pool.keys():
|
||||
self.pool[trait.dtype] = dict()
|
||||
if trait.hdim not in self.pool[trait.dtype].keys():
|
||||
self.pool[trait.dtype][trait.hdim] = list()
|
||||
|
||||
self.pool[trait.dtype][trait.hdim].append(copy.copy(trait))
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, hdim in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][hdim]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_pagedkv=BOOL_MAP[trait.pagedkv], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=trait.bn1, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
F_bm0 : int # tile size along q seqlen (block size)
|
||||
F_bn0 : int # tile size along k seqlen
|
||||
F_bk0 : int # tile size along qk gemm unroll
|
||||
F_bn1 : int # tile size along v head_dim
|
||||
F_bk1 : int # tile size along kv gemm unroll
|
||||
F_bk0max : int # total length of K0, used for pipeline that need load Q at once (or repeately load Q as a whole tile)
|
||||
F_rm0 : int # number of warps for gemm0 along q seqlen
|
||||
F_rn0 : int # number of warps for gemm0 along k seqlen
|
||||
F_rk0 : int # number of warps for gemm0 along head dim q (not used)
|
||||
F_rm1 : int # number of warps for gemm1 along q seqlen
|
||||
F_rn1 : int # number of warps for gemm1 along head dim v
|
||||
F_rk1 : int # number of warps for gemm1 along k seqlen (not used)
|
||||
F_wm0 : int # gemm0 warp size along m
|
||||
F_wn0 : int # gemm0 warp size along n
|
||||
F_wk0 : int # gemm0 warp size along k
|
||||
F_wm1 : int # gemm1 warp size along m
|
||||
F_wn1 : int # gemm1 warp size along n
|
||||
F_wk1 : int # gemm1 warp size along k
|
||||
F_occupancy : int # occupancy, -1 will let pipeline decide the occupancy, other value will overwrite occupancy
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return f"b{self.F_bm0}x{self.F_bn0}x{self.F_bk0}x{self.F_bn1}x{self.F_bk1}x{self.F_bk0max}" +\
|
||||
f"_r{self.F_rm0}x{self.F_rn0}x{self.F_rk0}_r{self.F_rm1}x{self.F_rn1}x{self.F_rk1}" +\
|
||||
f"_w{self.F_wm0}x{self.F_wn0}x{self.F_wk0}_w{self.F_wm1}x{self.F_wn1}x{self.F_wk1}" +\
|
||||
("" if self.F_occupancy == -1 else f"_o{self.F_occupancy}")
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdKernel:
|
||||
F_idx : int # this is not a tunable, but a counter to differentiate symbol
|
||||
F_hdim : int # hdim
|
||||
F_dtype : str # data type
|
||||
F_mode : str # value from MODE_MAP
|
||||
F_tile : FmhaFwdTileSize
|
||||
F_pipeline : FmhaFwdPipeline
|
||||
mask_impl : str
|
||||
|
||||
@property
|
||||
def template(self) -> str:
|
||||
kernel_body = str()
|
||||
return FMHA_FWD_KERNEL_HEADER + \
|
||||
FMHA_FWD_KERNEL_BODY.format(
|
||||
F_idx = self.F_idx,
|
||||
F_hdim = self.F_hdim,
|
||||
F_dtype = FWD_DTYPE_MAP[self.F_dtype],
|
||||
F_bm0 = self.F_tile.F_bm0,
|
||||
F_bn0 = self.F_tile.F_bn0,
|
||||
F_bk0 = self.F_tile.F_bk0,
|
||||
F_bn1 = self.F_tile.F_bn1,
|
||||
F_bk1 = self.F_tile.F_bk1,
|
||||
F_bk0max = self.F_tile.F_bk0max,
|
||||
F_rm0 = self.F_tile.F_rm0,
|
||||
F_rn0 = self.F_tile.F_rn0,
|
||||
F_rk0 = self.F_tile.F_rk0,
|
||||
F_rm1 = self.F_tile.F_rm1,
|
||||
F_rn1 = self.F_tile.F_rn1,
|
||||
F_rk1 = self.F_tile.F_rk1,
|
||||
F_wm0 = self.F_tile.F_wm0,
|
||||
F_wn0 = self.F_tile.F_wn0,
|
||||
F_wk0 = self.F_tile.F_wk0,
|
||||
F_wm1 = self.F_tile.F_wm1,
|
||||
F_wn1 = self.F_tile.F_wn1,
|
||||
F_wk1 = self.F_tile.F_wk1,
|
||||
F_vlayout = LAYOUT_MAP[self.F_pipeline.F_vlayout],
|
||||
F_spad = BOOL_MAP[self.F_pipeline.F_spad],
|
||||
F_skpad = BOOL_MAP[self.F_pipeline.F_skpad],
|
||||
F_dpad = BOOL_MAP[self.F_pipeline.F_dpad],
|
||||
F_dvpad = BOOL_MAP[self.F_pipeline.F_dvpad],
|
||||
F_logits = BOOL_MAP[self.F_pipeline.F_logits],
|
||||
F_bias = BIAS_MAP[self.F_pipeline.F_bias],
|
||||
F_lse = BOOL_MAP[self.F_pipeline.F_lse],
|
||||
F_pagedkv = BOOL_MAP[self.F_pipeline.F_pagedkv],
|
||||
F_squant = BOOL_MAP[self.F_pipeline.F_squant],
|
||||
F_skip = BOOL_MAP[self.F_pipeline.F_skip],
|
||||
F_occupancy = self.F_tile.F_occupancy,
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = FMHA_FWD_PAGEDKV_PIPELINE_MAP[self.F_pipeline.tag])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
# TODO: we don't encode idx here
|
||||
return f"fmha_fwd_pagedkv_d{self.F_hdim}_{self.F_dtype}_{self.F_mode}_" + \
|
||||
self.F_tile.name + '_' + self.F_pipeline.name
|
||||
|
||||
@property
|
||||
def filename(self) -> str:
|
||||
return self.name + ".cpp"
|
||||
|
||||
def api_trait(self) -> FmhaFwdApiTrait:
|
||||
return FmhaFwdApiTrait(
|
||||
pipeline_tag=self.F_pipeline.tag,
|
||||
hdim=str(self.F_hdim),
|
||||
dtype=self.F_dtype,
|
||||
mode=self.F_mode,
|
||||
bm0=self.F_tile.F_bm0,
|
||||
bn0=self.F_tile.F_bn0,
|
||||
bk0=self.F_tile.F_bk0,
|
||||
bn1=self.F_tile.F_bn1,
|
||||
bk1=self.F_tile.F_bk1,
|
||||
bk0max=self.F_tile.F_bk0max,
|
||||
vlayout=self.F_pipeline.F_vlayout,
|
||||
mask=self.F_pipeline.F_mask,
|
||||
logits=self.F_pipeline.F_logits,
|
||||
bias=self.F_pipeline.F_bias,
|
||||
lse=self.F_pipeline.F_lse,
|
||||
pagedkv=self.F_pipeline.F_pagedkv,
|
||||
squant=self.F_pipeline.F_squant,
|
||||
spad=self.F_pipeline.F_spad,
|
||||
skpad=self.F_pipeline.F_skpad,
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip)
|
||||
|
||||
# TODO: design a more practical way to do it
|
||||
# this is current supported tile size per hdim
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
# '32' : FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
### '96' : FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '192' : FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
}
|
||||
elif dtype == 'fp8' or dtype == 'bf8':
|
||||
return {
|
||||
'64' : FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 2, 1, 1, 2, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'128' : FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
'256' : FmhaFwdTileSize(128, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 32, 32, 32, 32, 32, 32, -1),
|
||||
}
|
||||
else:
|
||||
return None
|
||||
|
||||
def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl) -> Tuple[FmhaFwdApiPool, List[FmhaFwdKernel]]:
|
||||
# TODO: we don't support tuning yet, so pick up one value for vlayout/pipeline/pad
|
||||
# support this in future
|
||||
def get_pipelines(dtype, hdim) -> List[FmhaFwdPipeline]:
|
||||
# this function will populate a list possible pipelines
|
||||
# TODO: the order of List matters! the later in this list will be also be checked later
|
||||
# TODO: currently for qr_pagedkv pipeline, let 't' padding to appear later!!
|
||||
# TODO: how to design this more generic?
|
||||
squant = 't' if dtype == 'fp8' else 'f'
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, pagedkv, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t"], ["f"]):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 'f', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', pagedkv, squant, mask, skip))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 'f', 'f', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_pagedkv', 'row', 't', 't', 'f', 'f', logits, bias, 'f', 't', squant, mask, 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
else:
|
||||
assert False
|
||||
return pipelines
|
||||
|
||||
gen = list()
|
||||
api_pool = FmhaFwdApiPool(mask_impl)
|
||||
|
||||
for dtype in FWD_DTYPE_MAP.keys():
|
||||
d = get_fmha_fwd_tile_dict_from_dtype(dtype)
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
# if pipeline.F_pagedkv == 'f':
|
||||
# continue
|
||||
if mode == "group":
|
||||
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
|
||||
# in group mode, spad/skpad must be true, since we can't predict if seqlen of current batch need pad or not
|
||||
continue
|
||||
if hdim == 192 and tile.F_bn1 == 128:
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_lse == 't' :
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
k = FmhaFwdKernel(F_idx=0,
|
||||
F_hdim=hdim,
|
||||
F_dtype=dtype,
|
||||
F_mode=mode,
|
||||
F_tile=tile,
|
||||
F_pipeline=pipeline,
|
||||
mask_impl=mask_impl)
|
||||
if kernel_filter != '':
|
||||
if not fnmatch.fnmatch(k.name, kernel_filter):
|
||||
continue
|
||||
if optdim_list != [-1]:
|
||||
if hdim not in optdim_list:
|
||||
continue
|
||||
# 2 - Flash attention integration
|
||||
if receipt in (2, 3):
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'alibi']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# PyTorch integration
|
||||
elif receipt == 4:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_bias in ['no', 'bias']
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
cond &= pipeline.F_skip == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_fwd) integration
|
||||
elif receipt == 100:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'batch'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# Aiter(mha_varlen_fwd) integration
|
||||
elif receipt == 200:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= mode == 'group'
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
# aiter::mha_fwd C++ api integration
|
||||
elif receipt == 600:
|
||||
cond = dtype in ['fp16', 'bf16']
|
||||
cond &= pipeline.F_vlayout == 'row'
|
||||
cond &= pipeline.F_squant == 'f'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
# fp32 only
|
||||
if receipt == 800 or receipt == 801:
|
||||
cond = dtype == 'fp32'
|
||||
if not cond:
|
||||
continue
|
||||
|
||||
api_pool.register_traits(k.api_trait())
|
||||
gen.append(k)
|
||||
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
write_single_fwd_kernel(kernel, output_dir)
|
||||
write_fwd_api(api_pool, output_dir)
|
||||
|
||||
def list_blobs(file_path : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
with file_path.open('a') as f:
|
||||
_, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
for kernel in kernels:
|
||||
f.write(str(file_path.parent / GEN_DIR / kernel.filename) + "\n")
|
||||
f.write(str(file_path.parent / GEN_DIR / FMHA_FWD_API_FILENAME) + "\n")
|
||||
@@ -1,21 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import os.path as path
|
||||
|
||||
|
||||
def update_file(file_path, content):
|
||||
"""Update the file at file_path with the given content if it differs from the existing content.
|
||||
|
||||
It avoids unnecessary touching of the file which triggers rebuilds
|
||||
"""
|
||||
|
||||
existing_content = ""
|
||||
if path.exists(file_path):
|
||||
with open(file_path, "r") as file:
|
||||
existing_content = file.read()
|
||||
if existing_content == content:
|
||||
return
|
||||
with open(file_path, "w") as file:
|
||||
file.write(content)
|
||||
@@ -1,132 +0,0 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import pkgutil
|
||||
import sys
|
||||
from typing import List, Optional
|
||||
|
||||
import codegen.ops
|
||||
from codegen.cmake_config import *
|
||||
|
||||
|
||||
class HandlerId(IntEnum):
|
||||
LIST_BLOBS = 0
|
||||
WRITE_BLOBS = 1
|
||||
|
||||
# inspect all modules under 'codegen.ops' and register API handlers
|
||||
ops = []
|
||||
for importer, module_name, _ in pkgutil.iter_modules(codegen.ops.__path__):
|
||||
full_module_name = '%s.%s' % (codegen.ops.__name__, module_name)
|
||||
ops.append(importer.find_spec(module_name).loader.load_module(module_name))
|
||||
unwanted_prefix = 'fmha_'
|
||||
handlers = dict(
|
||||
[(op.__name__[len(unwanted_prefix):] if op.__name__.startswith(unwanted_prefix) else op.__name__,
|
||||
(op.list_blobs, op.write_blobs)) for op in ops]
|
||||
)
|
||||
assert 0 < len(handlers)
|
||||
|
||||
def write_blobs(output_dir: Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
|
||||
if output_dir is None:
|
||||
output_dir = Path(__file__).parent
|
||||
else:
|
||||
output_dir = Path(output_dir) / GEN_DIR
|
||||
|
||||
output_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.WRITE_BLOBS]
|
||||
handler(output_dir, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
# list all the files that will be generated
|
||||
def list_blobs(output_file : Optional[str], api_list : List[str], filters_list : List[str], optdim_list : List[int], receipt, mask_impl) -> None:
|
||||
assert output_file is not None
|
||||
file_path = Path(output_file)
|
||||
|
||||
# create an empty file / drop its contents if it exists
|
||||
open(file_path, "w").close()
|
||||
|
||||
for api, kernel_filter in zip(api_list, filters_list):
|
||||
handler = handlers[api][HandlerId.LIST_BLOBS]
|
||||
handler(file_path, kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK fmha kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--direction", # we keep 'direction' option for backward compatibility
|
||||
"-a",
|
||||
"--api",
|
||||
default='fwd',
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma."
|
||||
)
|
||||
parser.add_argument(
|
||||
"-o",
|
||||
"--output_dir",
|
||||
required=False,
|
||||
help="write all the blobs into a directory"
|
||||
)
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
required=False,
|
||||
help="list all the kernels to a file"
|
||||
)
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
default='',
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--mask",
|
||||
default="simplified",
|
||||
required=False,
|
||||
help="mask implementation, simplified/generic"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt. 0: generate only 8xhdim coverage\n" + \
|
||||
" 1: generate more instance to cover all hdim\n" + \
|
||||
" 2: Only generate instance for Flash attention integration\n" + \
|
||||
" 4: Only generate instance for PyTorch integration\n" + \
|
||||
" 100-199: Only generate instance for Aiter(mha_fwd) integration\n" + \
|
||||
" 200-299: Only generate instance for Aiter(mha_varlen_fwd) integration\n" + \
|
||||
" 300-399: Only generate instance for Aiter(mha_bwd) integration\n" + \
|
||||
" 400-499: Only generate instance for Aiter(mha_varlen_bwd) integration\n" + \
|
||||
" 600-699: Only generate instance for aiter::mha_fwd && aiter::mha_fwd_splitkv && aiter::mha_bwd C++ api integration"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--optdim",
|
||||
default='-1',
|
||||
required=False,
|
||||
help="only optimize the hdim in the list. separated by comma. -1 is the default choice" + \
|
||||
"eg. --optdim=32,64,128,256"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
api_list = args.direction.split(',')
|
||||
filter_list = args.filter.split(',')
|
||||
filter_list.extend([''] * (len(api_list) - len(filter_list)))
|
||||
optdim_list = [int(hdim) for hdim in args.optdim.split(',')]
|
||||
|
||||
if args.list_blobs is not None:
|
||||
list_blobs(args.list_blobs, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
|
||||
else:
|
||||
write_blobs(args.output_dir, api_list, filter_list, optdim_list, int(args.receipt), mask_impl=args.mask)
|
||||
@@ -1,20 +0,0 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_bwd -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for perm in 0 1 ; do
|
||||
for hdim in 32 64 128 ; do
|
||||
|
||||
nhead=$((2048 / $hdim)) # follow fav2 setup
|
||||
$EXE -prec=$prec -b=32 -h=$nhead -d=$hdim -s=512 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=16 -h=$nhead -d=$hdim -s=1024 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=8 -h=$nhead -d=$hdim -s=2048 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=4 -h=$nhead -d=$hdim -s=4096 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=2 -h=$nhead -d=$hdim -s=8192 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kname=1 -v=$VALID ; sleep 3
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
@@ -1,42 +0,0 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd_v3 -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for causal in 0 1 ; do
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for hdim in 128 ; do
|
||||
for perm in 0 ; do
|
||||
|
||||
$EXE -prec=$prec -b=32 -h=16 -s=512 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=16 -h=16 -s=1024 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=8 -h=16 -s=2048 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=4 -h=16 -s=4096 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=2 -h=16 -s=8192 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
$EXE -prec=$prec -b=1 -h=64 -s=16384 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=16 -h_k=1 -s=65536 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
$EXE -prec=$prec -b=1 -h=40 -s=37200 -d=$hdim -causal=$causal -iperm=$perm -operm=$perm -v=$VALID
|
||||
|
||||
done
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
# Padding benchmark comparisons for v3 (batch mode only)
|
||||
# ==== V3 Padding Benchmarks: batch mode (baseline vs low/med/high pad) ====
|
||||
prec="fp16"
|
||||
base_v3_args="-prec=$prec -b=4 -h=16 -d=128 -s=1024 -mask=0 -iperm=0 -operm=0 -v=$VALID"
|
||||
|
||||
# baseline (no pad)
|
||||
$EXE $base_v3_args
|
||||
|
||||
# low pad (≈90–95% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=1024,960,992,896 -kv_eff_lens=1024,960,992,896
|
||||
|
||||
# medium pad (≈60–75% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=896,768,512,640 -kv_eff_lens=896,768,512,640
|
||||
|
||||
# high pad (≈30–40% effective)
|
||||
$EXE $base_v3_args -q_eff_lens=512,384,256,320 -kv_eff_lens=512,384,256,320
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,244 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <optional>
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "ck_tile/core/container/span.hpp"
|
||||
|
||||
enum class mode_enum
|
||||
{
|
||||
batch = 0,
|
||||
group
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& stream, mode_enum mode)
|
||||
{
|
||||
return stream << (mode == mode_enum::batch ? "batch" : "group");
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
{
|
||||
using size_type = typename std::vector<T>::size_type;
|
||||
|
||||
os << "[";
|
||||
for(size_type idx = 0; idx < v.size(); ++idx)
|
||||
{
|
||||
if(0 < idx)
|
||||
{
|
||||
os << ", ";
|
||||
}
|
||||
os << v[idx];
|
||||
}
|
||||
return os << "]";
|
||||
}
|
||||
|
||||
std::vector<int32_t> to_seqstarts(ck_tile::span<const int32_t> seqlens)
|
||||
{
|
||||
std::vector<int32_t> seqstarts = {0};
|
||||
for(int32_t seqlen : seqlens)
|
||||
{
|
||||
seqstarts.push_back(seqstarts.back() + seqlen);
|
||||
}
|
||||
assert(seqstarts.size() == seqlens.size() + 1);
|
||||
return seqstarts;
|
||||
}
|
||||
|
||||
template <typename RandomEngine>
|
||||
std::vector<int32_t> generate_seqlens(mode_enum mode,
|
||||
unsigned count,
|
||||
int32_t seqlen_avg,
|
||||
int32_t seqlen_min, // if not negative, clamp min
|
||||
int32_t seqlen_max, // if not negative, clamp max
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
assert(0 < count);
|
||||
|
||||
seqlen_min = (0 < seqlen_min ? seqlen_min : 1);
|
||||
seqlen_max = (0 < seqlen_max ? seqlen_max : std::numeric_limits<int32_t>::max());
|
||||
assert(seqlen_min <= seqlen_max);
|
||||
|
||||
std::vector<int32_t> seqlens(count, std::clamp(seqlen_avg, seqlen_min, seqlen_max));
|
||||
|
||||
if(mode == mode_enum::group && 1 < count)
|
||||
{
|
||||
using size_type = std::vector<int32_t>::size_type;
|
||||
|
||||
std::uniform_int_distribution<size_type> idx_dist(0, count - 1);
|
||||
auto next_idx = std::bind(idx_dist, std::ref(random_engine));
|
||||
|
||||
std::uniform_int_distribution<size_type> step_dist(1, count - 1);
|
||||
auto next_step = std::bind(step_dist, std::ref(random_engine));
|
||||
|
||||
for(unsigned repeat = seqlen_avg * (count / 2); 0 < repeat; --repeat)
|
||||
{
|
||||
const size_type to_decrease = next_idx();
|
||||
// make sure each elements of seqlens is in range [seqlen_min, seqlen_max]
|
||||
if(seqlens[to_decrease] == seqlen_min)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
const size_type to_increase = (to_decrease + next_step()) % count;
|
||||
|
||||
if(seqlens[to_increase] >= seqlen_max)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
--seqlens[to_decrease];
|
||||
++seqlens[to_increase];
|
||||
}
|
||||
}
|
||||
|
||||
return seqlens;
|
||||
}
|
||||
|
||||
// return random integer generated uniformly in range [low, high]
|
||||
template <typename Int = int, typename RandomEngine>
|
||||
auto randint(Int low,
|
||||
Int high,
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>, Int>
|
||||
{
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
return dist(random_engine);
|
||||
}
|
||||
|
||||
// return random integers generated uniformly in range [low, high]
|
||||
template <typename Int, typename ForwardIterator, typename RandomEngine>
|
||||
auto randints(ForwardIterator first,
|
||||
ForwardIterator last,
|
||||
Int low,
|
||||
Int high,
|
||||
RandomEngine& random_engine) -> std::enable_if_t<std::is_integral_v<Int>>
|
||||
{
|
||||
std::uniform_int_distribution<Int> dist(low, high);
|
||||
|
||||
std::generate(first, last, [&] { return dist(random_engine); });
|
||||
}
|
||||
|
||||
/*
|
||||
* generate missing values in *_val randomly when the number of values is smaller than batch
|
||||
* example (assume batch=3)
|
||||
* q_val=1,2,3 k_val=4,5,6 -> OK
|
||||
* q_val=1,2,3 -> OK, k same as q
|
||||
* q_val=1,2 -> OK, q will rand remaining 1 element, k same as q
|
||||
* q_val=1,2 k_val=4,5 -> OK, q/k will rand remaining 1 element
|
||||
* q_val=1,2,3,4 -> OK, but ignore exceed one
|
||||
*
|
||||
* q_val=1,2 k_val=4,5,6 -> not OK, k must have same splits with q
|
||||
* q_val=1,2 k_val=4 -> not OK, k must have same splits with q
|
||||
*/
|
||||
template <typename RandomEngine>
|
||||
std::tuple<std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>,
|
||||
std::vector<ck_tile::index_t>>
|
||||
generate_missing_seqlens(mode_enum mode,
|
||||
ck_tile::index_t batch,
|
||||
const std::vector<ck_tile::index_t>& q_val,
|
||||
const std::vector<ck_tile::index_t>& k_val,
|
||||
const std::vector<ck_tile::index_t>& k_pad_val,
|
||||
ck_tile::index_t seqlen_k_min,
|
||||
bool need_append_kvcache,
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
if(mode == mode_enum::batch)
|
||||
{
|
||||
ck_tile::index_t q = q_val[0];
|
||||
ck_tile::index_t k = k_val[0];
|
||||
|
||||
auto s_q = std::vector<ck_tile::index_t>(batch, q);
|
||||
auto s_k = [&] {
|
||||
const ck_tile::index_t seqlen_k_max = (k < 0 ? q : k);
|
||||
std::vector<ck_tile::index_t> seqlen_ks(batch, seqlen_k_max);
|
||||
|
||||
if(1 < batch && need_append_kvcache)
|
||||
{
|
||||
// to keep the original s_k value, we always use seqlen_k_max in first batch
|
||||
randints(std::next(seqlen_ks.begin()),
|
||||
seqlen_ks.end(),
|
||||
seqlen_k_min,
|
||||
seqlen_k_max,
|
||||
random_engine);
|
||||
return seqlen_ks;
|
||||
}
|
||||
|
||||
return seqlen_ks;
|
||||
}();
|
||||
auto s_kpad = std::vector<ck_tile::index_t>(batch, -1); // TODO: batch not support k_padding
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min if provided
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
else
|
||||
{
|
||||
std::vector<ck_tile::index_t> s_q;
|
||||
std::vector<ck_tile::index_t> s_k;
|
||||
std::vector<ck_tile::index_t> s_kpad;
|
||||
ck_tile::index_t idx = 0;
|
||||
for(; idx < std::min(static_cast<ck_tile::index_t>(q_val.size()), batch); ++idx)
|
||||
{
|
||||
ck_tile::index_t q = q_val[idx];
|
||||
ck_tile::index_t k =
|
||||
k_val[std::min(idx, static_cast<ck_tile::index_t>(k_val.size()) - 1)];
|
||||
ck_tile::index_t kp =
|
||||
k_pad_val.empty()
|
||||
? -1
|
||||
: k_pad_val[std::min(idx, static_cast<ck_tile::index_t>(k_pad_val.size()) - 1)];
|
||||
|
||||
s_q.push_back(q);
|
||||
s_k.push_back(k < 0 ? q : k);
|
||||
s_kpad.push_back(kp);
|
||||
|
||||
// s_k should be greater than or equal to seqlen_k_min
|
||||
if(s_k.back() < seqlen_k_min)
|
||||
{
|
||||
std::ostringstream msg;
|
||||
msg << __FILE__ << ":" << __LINE__ << ": seqlen_k (=" << s_k.back()
|
||||
<< ") is less than minimum seqlen_k (=" << seqlen_k_min << ")";
|
||||
throw std::runtime_error(msg.str());
|
||||
}
|
||||
}
|
||||
if(idx < batch)
|
||||
{
|
||||
auto rem_q =
|
||||
generate_seqlens(mode, batch - idx, s_q.back(), 1, s_q.back(), random_engine);
|
||||
auto rem_k = generate_seqlens(
|
||||
mode, batch - idx, s_k.back(), seqlen_k_min, s_kpad.back(), random_engine);
|
||||
|
||||
s_q.insert(s_q.end(), rem_q.begin(), rem_q.end());
|
||||
s_k.insert(s_k.end(), rem_k.begin(), rem_k.end());
|
||||
s_kpad.insert(s_kpad.end(), batch - idx, s_kpad.back());
|
||||
}
|
||||
return std::make_tuple(s_q, s_k, s_kpad);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RandomAccessIterator, typename Int, typename RandomEngine>
|
||||
std::enable_if_t<std::is_integral_v<Int>> iota_shuffle(RandomAccessIterator first,
|
||||
RandomAccessIterator last,
|
||||
Int value,
|
||||
RandomEngine& random_engine)
|
||||
{
|
||||
std::iota(first, last, value);
|
||||
std::shuffle(first, last, random_engine);
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
# fused multi-head attention
|
||||
|
||||
This folder contains example for fmha(fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast.
|
||||
This folder contains example for unified attention (fused multi-head attention) using ck_tile tile-programming implementation. It is a good example to demonstrate the usage of tile-programming API, as well as illustrate the new approach to construct a kernel template and instantiate it(them) while keeping compile time fast.
|
||||
|
||||
## build
|
||||
```
|
||||
@@ -8,12 +8,12 @@ This folder contains example for fmha(fused multi-head attention) using ck_tile
|
||||
mkdir build && cd build
|
||||
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
|
||||
../script/cmake-ck-dev.sh ../ <arch>
|
||||
make tile_example_fmha_fwd -j
|
||||
make tile_example_unified_attention -j
|
||||
```
|
||||
This will result in an executable `build/bin/tile_example_fmha_fwd`
|
||||
This will result in an executable `build/bin/tile_example_unified_attention`
|
||||
|
||||
## kernel
|
||||
The kernel template is `fmha_fwd_kernel.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
|
||||
The kernel template is `unified_attention.hpp`, this is the grid-wise op in old ck_tile's terminology. We put it here purposely, to demonstrate one can construct a kernel by using various internal component from ck_tile. We may still have an implementation under ck_tile's include path (in the future) for the kernel template.
|
||||
|
||||
There are 2 template parameters for this kernel template.
|
||||
* `FmhaPipeline` is one of the block_tile_pipeline(under `include/ck_tile/tile_program/block_tile_pipeline`) which is a performance critical component. Indeed, we did a lot of optimization and trials to optimize the pipeline and may still workout more performance pipeline and update into that folder. People only need to replace this pipeline type and would be able to enjoy the benefit of different performant implementations (stay tuned for updated pipeline(s)).
|
||||
@@ -23,7 +23,7 @@ There are 2 template parameters for this kernel template.
|
||||
To speed up compile time, we instantiate the kernels into separate file. In this way we can benefit from parallel building from CMake/Make system. This is achieved by `generate.py` script. Besides, you can look into this script to learn how to instantiate a kernel instance step by step, which is described in `FMHA_FWD_KERNEL_BODY` variable.
|
||||
|
||||
## executable
|
||||
`tile_example_fmha_fwd` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_fmha_fwd -?` to list all the arguments. Below is an example of the output (may subject to change)
|
||||
`tile_example_unified_attention` is the example executable, implemented in `fmha_fwd.cpp`. You can type `./bin/tile_example_unified_attention -?` to list all the arguments. Below is an example of the output (may subject to change)
|
||||
```
|
||||
args:
|
||||
-v weather do CPU validation or not (default:1)
|
||||
@@ -88,14 +88,14 @@ args:
|
||||
-kv_eff_lens Batch-mode only: per-batch effective seqlen for KV (exclude PAD) (default:"")
|
||||
Comma-separated list of length 'b'. If empty, no override
|
||||
```
|
||||
Example 1: `./bin/tile_example_fmha_fwd -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_fmha_fwd -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
Example 1: `./bin/tile_example_unified_attention -b=1 -h=16 -s=16384 -d=128` will run a fmha case with batch=1, nhead=16, sequence length=16384, hdim=128, fp16 case.
|
||||
Example 2: `./bin/tile_example_unified_attention -b=1 -h=8 -s=16384 -d=64 -drop_prefs=1 -drop_seed=10 -drop_offset=1234` will run a fmha case with
|
||||
batch=1, nhead=8, sequence length=16384, hdim=64, drop_seed=0 (in GPU memory), drop_offset=1234 (in GPU memory) fp16 case
|
||||
|
||||
## Padding Examples
|
||||
Example 3 (Group mode with padding): `./bin/tile_example_fmha_fwd -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
|
||||
Example 3 (Group mode with padding): `./bin/tile_example_unified_attention -mode=1 -b=2 -h=8 -s=1024,2048 -s_k=1024,2048 -s_qpad=1536,3072 -s_kpad=1536,3072 -d=128` will run group mode with 2 batches having different sequence lengths (1024, 2048) but physically padded to (1536, 3072) respectively.
|
||||
|
||||
Example 4 (Batch mode with effective lengths): `./bin/tile_example_fmha_fwd -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
|
||||
Example 4 (Batch mode with effective lengths): `./bin/tile_example_unified_attention -mode=0 -b=2 -h=8 -s=2048 -s_k=2048 -d=128 -q_eff_lens=1024,1536 -kv_eff_lens=1024,1536` will run batch mode where all batches use 2048 as physical sequence length but have effective lengths of (1024, 1536) for Q and KV respectively.
|
||||
|
||||
## support features
|
||||
Currently we are still in rapid development stage, so more features/optimizations will be coming soon.
|
||||
@@ -154,6 +154,6 @@ We support sequence padding and variable-length processing in both batch and gro
|
||||
Both approaches optimize memory access patterns while supporting flexible sequence length requirements commonly found in transformer inference scenarios.
|
||||
|
||||
## FP8 experimental support
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_fmha_fwd`, on a gfx942 machine and ROCm 6.0+.
|
||||
As described in [this blog](https://blog.hippoml.com/8bit-hippoattention-up-to-3x-faster-compared-to-flashattentionv2-8f9def90b482), we have an experimental support for fp8 fmha kernels, you can evaluate the performance by setting the arg `-prec=fp8` to the `tile_example_unified_attention`, on a gfx942 machine and ROCm 6.0+.
|
||||
|
||||
Currently we only support `-vlayout=r`( `seqlen*hdim` for V matrix) for fp8 and fp8bf16 now. Full feature support will come later.
|
||||
|
Before Width: | Height: | Size: 29 KiB After Width: | Height: | Size: 29 KiB |
@@ -1,6 +1,6 @@
|
||||
#!/bin/sh
|
||||
# TODO: run this script from CK root or build directory
|
||||
EXE="$(find . -name tile_example_fmha_fwd -type f | head -n 1)"
|
||||
EXE="$(find . -name tile_example_unified_attention -type f | head -n 1)"
|
||||
VALID=0
|
||||
|
||||
for prec in "fp16" "bf16" ; do
|
||||
@@ -4,7 +4,6 @@
|
||||
include_directories(AFTER
|
||||
${CMAKE_CURRENT_LIST_DIR}
|
||||
)
|
||||
add_subdirectory(01_unified_attention)
|
||||
add_subdirectory(01_fmha)
|
||||
add_subdirectory(02_layernorm2d)
|
||||
add_subdirectory(03_gemm)
|
||||
@@ -30,4 +29,4 @@ add_subdirectory(36_pooling)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
add_subdirectory(40_streamk_gemm)
|
||||
add_subdirectory(41_batched_contraction)
|
||||
|
||||
add_subdirectory(42_unified_attention)
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/unified_attention/pipeline/unified_attention_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_v3_pipeline.hpp"
|
||||
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
|
||||
#define ENABLE_ASM_MARKER 1
|
||||
#if ENABLE_ASM_MARKER
|
||||
@@ -34,217 +35,6 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename PipelineProblem, bool kIsMasking>
|
||||
struct CoreLoopScheduler;
|
||||
|
||||
template <typename PipelineProblem>
|
||||
struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/true>
|
||||
{
|
||||
template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
|
||||
CK_TILE_DEVICE static constexpr void schedule(ck_tile::number<WaveGroup>,
|
||||
ck_tile::number<Phase>)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
if constexpr(WaveGroup == 0)
|
||||
{
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
#endif
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
#endif
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PipelineProblem>
|
||||
struct CoreLoopScheduler<PipelineProblem, /*kIsMasking=*/false>
|
||||
{
|
||||
template <ck_tile::index_t WaveGroup, ck_tile::index_t Phase>
|
||||
CK_TILE_DEVICE static constexpr void schedule(ck_tile::number<WaveGroup>,
|
||||
ck_tile::number<Phase>)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
if constexpr(WaveGroup == 0)
|
||||
{
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
#endif
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(Phase == 0)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 1)
|
||||
{
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 2, 0); // TRANS
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
});
|
||||
}
|
||||
else if constexpr(Phase == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 2, 0); // VALU
|
||||
__builtin_amdgcn_sched_group_barrier(0x004, 4, 0); // SALU
|
||||
}
|
||||
else if constexpr(Phase == 3)
|
||||
{
|
||||
#if !CK_TILE_DISABLE_PACKED_FP32
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
#endif
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
__builtin_amdgcn_sched_group_barrier(0x002, 4, 0); // VALU
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
CK_TILE_DEVICE float fma_impl_vsv(float a, float b, float c)
|
||||
{
|
||||
#if CK_TILE_DISABLE_PACKED_FP32
|
||||
return a * b + c;
|
||||
#else
|
||||
float result;
|
||||
asm volatile("v_fma_f32 %[result], %[a], %[b], %[c]"
|
||||
: [result] "=v"(result)
|
||||
: [a] "v"(a), [b] "s"(b), [c] "v"(c));
|
||||
return result;
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float add_impl_vv(float lhs, float rhs)
|
||||
{
|
||||
float result;
|
||||
asm volatile("v_add_f32_e32 %[result], %[lhs], %[rhs]"
|
||||
: [result] "=v"(result)
|
||||
: [lhs] "v"(lhs), [rhs] "v"(rhs));
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE float mul_impl_vv(float lhs, float rhs)
|
||||
{
|
||||
float result;
|
||||
asm volatile("v_mul_f32_e32 %[result], %[lhs], %[rhs]"
|
||||
: [result] "=v"(result)
|
||||
: [lhs] "v"(lhs), [rhs] "v"(rhs));
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp16x2_t cvt_pk_fp16_f32(float a, float b)
|
||||
{
|
||||
fp16x2_t result;
|
||||
asm volatile("v_cvt_pk_f16_f32 %[result], %[a], %[b]"
|
||||
: [result] "=v"(result)
|
||||
: [a] "v"(a), [b] "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE bf16x2_t cvt_pk_bf16_f32(float a, float b)
|
||||
{
|
||||
bf16x2_t result;
|
||||
asm volatile("v_cvt_pk_bf16_f32 %[result], %[a], %[b]"
|
||||
: [result] "=v"(result)
|
||||
: [a] "v"(a), [b] "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE fp32x2_t pk_mul_f32(fp32x2_t lhs, fp32x2_t rhs)
|
||||
{
|
||||
fp32x2_t result;
|
||||
asm volatile("v_pk_mul_f32 %[result], %[lhs], %[rhs]"
|
||||
: [result] "=v"(result)
|
||||
: [lhs] "v"(lhs), [rhs] "v"(rhs));
|
||||
return result;
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
template <typename Problem_, typename Policy_ = UnifiedAttentionPipelineDefaultPolicy>
|
||||
struct UnifiedAttentionPipeline
|
||||
{
|
||||
@@ -377,23 +167,24 @@ struct UnifiedAttentionPipeline
|
||||
typename SAccElementFunction,
|
||||
typename PComputeElementFunction,
|
||||
typename OAccElementFunction>
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
[[maybe_unused]] const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
[[maybe_unused]] const VElementFunction& v_element_func,
|
||||
const index_t num_blocks,
|
||||
const index_t num_blocks_start,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
const index_t kv_page_size_in_blocks,
|
||||
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
CK_TILE_DEVICE auto operator()(
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp, // kBlockM * kHeadDimPadded tile
|
||||
const QElementFunction& q_element_func,
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // kPageBlockSize * kHeadDimPadded tile
|
||||
[[maybe_unused]] const KElementFunction& k_element_func,
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // kHeadDimPadded * kPageBlockSize tile
|
||||
[[maybe_unused]] const VElementFunction& v_element_func,
|
||||
const index_t num_blocks,
|
||||
const index_t num_blocks_start,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
const index_t kv_page_size_in_blocks,
|
||||
[[maybe_unused]] const SAccElementFunction& s_acc_element_func,
|
||||
const PComputeElementFunction& p_compute_element_func,
|
||||
const OAccElementFunction& o_acc_element_func,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
static_assert(
|
||||
@@ -1224,17 +1015,18 @@ struct UnifiedAttentionPipeline
|
||||
template <typename QDramBlockWindowTmp,
|
||||
typename KDramBlockWindowTmp,
|
||||
typename VDramBlockWindowTmp>
|
||||
CK_TILE_DEVICE auto operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
|
||||
const index_t num_blocks,
|
||||
const index_t num_blocks_start,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
const index_t kv_page_size_in_blocks,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
CK_TILE_DEVICE auto operator()(
|
||||
const QDramBlockWindowTmp& q_dram_block_window_tmp, // kBlockM * kHeadDimPadded tile
|
||||
const KDramBlockWindowTmp& k_dram_block_window_tmp, // kPageBlockSize * kHeadDimPadded tile
|
||||
const VDramBlockWindowTmp& v_dram_block_window_tmp, // kHeadDimPadded * kPageBlockSize tile
|
||||
const index_t num_blocks,
|
||||
const index_t num_blocks_start,
|
||||
const void* block_tables_ptr,
|
||||
index_t block_table_offset,
|
||||
const index_t kv_page_size_in_blocks,
|
||||
FmhaMask mask,
|
||||
float scale_s,
|
||||
void* smem_ptr) const
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user