mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-25 09:37:42 +00:00
temp save, change all instance to 1wave
This commit is contained in:
@@ -37,6 +37,12 @@ K0_MAX_SUBMAX_MAP = {
|
||||
256: 256
|
||||
}
|
||||
|
||||
SEQLENQ_MAP = {
|
||||
"16" : "16",
|
||||
"32" : "32",
|
||||
# "64" : "64"
|
||||
}
|
||||
|
||||
FMHA_FWD_DECODE_PIPELINE_MAP = {
|
||||
"decode_qr" : "ck_tile::BlockFmhaFwdDecodePipelineQRKSVS",
|
||||
}
|
||||
@@ -288,7 +294,7 @@ float fmha_fwd_decode(fmha_fwd_decode_traits t, fmha_fwd_decode_args a, const ck
|
||||
"""
|
||||
|
||||
FMHA_FWD_DECODE_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.do_fp8_static_quant == {F_squant}) &&
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck})) {{
|
||||
((a.block_table_ptr != nullptr) == {F_pagedkv}) && ({F_scheck}) && ({F_skcheck}) && ({F_dcheck})&& (a.seqlen_q <= {F_bm0}) && ({F_dvcheck})) {{
|
||||
using traits_ = fmha_fwd_decode_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_squant}, {F_pagedkv}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}>;
|
||||
|
||||
// get combine kernel tile sizes
|
||||
@@ -346,6 +352,7 @@ class FmhaFwdDecodeApiTrait:
|
||||
f'{self.vlayout}-{self.logits}-{self.mask}-{self.bias}-{self.lse}-{self.squant}-{self.spad}-{self.skpad}-{self.dpad}-'+\
|
||||
f'{self.dvpad}-{self.pagedkv}'
|
||||
|
||||
# sequence length as non-fast-changing dimension, we can always relay on instruction level OOB guard
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
@@ -362,12 +369,15 @@ class FmhaFwdDecodeApiTrait:
|
||||
else : return 'true'
|
||||
else: assert False
|
||||
|
||||
# head dimension as fast-changing dimension, we assume is multiple of 8
|
||||
@property
|
||||
def dcheck(self) -> str:
|
||||
if self.pipeline_tag in ['decode_qr', 'qr_nwarp_sshuffle']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
# if self.skpad == 't' : return 'true'
|
||||
# else : return 'true'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
@@ -376,6 +386,8 @@ class FmhaFwdDecodeApiTrait:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
# if self.skpad == 't' : return 'true'
|
||||
# else : return 'true'
|
||||
else: assert False
|
||||
|
||||
@dataclass
|
||||
@@ -637,19 +649,17 @@ class FmhaFwdSplitKVCombineKernel:
|
||||
def get_fmha_fwd_tile_dict_from_dtype(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
# '32' : FmhaFwdTileSize(32, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'64' : FmhaFwdTileSize(16, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '64' : FmhaFwdTileSize(32, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '64' : FmhaFwdTileSize(64, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '64' : FmhaFwdTileSize(128, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '64' : FmhaFwdTileSize(256, 64, 64, 64, 64, 64, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
### '96' : FmhaFwdTileSize(64, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1),
|
||||
'128' : FmhaFwdTileSize(16, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(32, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(64, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(128, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '128' : FmhaFwdTileSize(256, 64, 64, 128, 64, 128, 1, 4, 1, 1, 4, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
# '256' : FmhaFwdTileSize(64, 128, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'64': {
|
||||
# # Specialize for different SeqQ
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64': FmhaFwdTileSize(64, 64, 64, 64, 64, 64, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
},
|
||||
'128': {
|
||||
'16': FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
'32': FmhaFwdTileSize(32, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
# '64': FmhaFwdTileSize(64, 64, 64, 128, 64, 128, 4, 1, 1, 4, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
},
|
||||
}
|
||||
else:
|
||||
return None
|
||||
@@ -684,6 +694,7 @@ def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> T
|
||||
for lse in ['t', 'f']:
|
||||
if hdim in [64, 128]: ### [32, 64, 96, 128]:
|
||||
pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, squant, pagedkv, mask))
|
||||
pipelines.append(Pipeline('decode_qr', 'row', 'f', 'f', 't', 't', logits, bias, lse, squant, pagedkv, mask))
|
||||
else:
|
||||
assert False
|
||||
else:
|
||||
@@ -698,8 +709,8 @@ def get_fwd_decode_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> T
|
||||
if d == None:
|
||||
continue
|
||||
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
|
||||
for hdim_str, mode in itertools.product(d.keys(), MODE_MAP.keys()):
|
||||
tile = d[hdim_str]
|
||||
for hdim_str, mode, seqlenq in itertools.product(d.keys(), MODE_MAP.keys(), SEQLENQ_MAP.keys()):
|
||||
tile = d[hdim_str][seqlenq]
|
||||
hdim = int(hdim_str)
|
||||
for pipeline in get_pipelines(dtype, hdim):
|
||||
if mode == "group":
|
||||
@@ -762,7 +773,7 @@ def get_fwd_decode_combine_blobs(kernel_filter : Optional[str], receipt) -> List
|
||||
pipelines = []
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
# for spad, dvpad, lse in itertools.product(["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
for spad, dvpad, lse in itertools.product(["f"], ["f"], ["t", "f"]):
|
||||
for spad, dvpad, lse in itertools.product(["f"], ["t", "f"], ["t", "f"]):
|
||||
pipelines.append(Pipeline('unused', spad, dvpad, lse, squant))
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse kernels
|
||||
|
||||
@@ -696,6 +696,15 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
ck_tile::FillUniformDistribution<VDataType>{0.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "exp" || init_method == "99")
|
||||
{
|
||||
ck_tile::FillUniformDistribution<QDataType>{0.f, 1.f, seed}(q_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(k_host);
|
||||
ck_tile::FillUniformDistribution<KDataType>{0.f, 1.f, seed}(knew_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(v_host);
|
||||
ck_tile::FillUniformDistribution<VDataType>{1.f, 1.f, seed}(vnew_host);
|
||||
ck_tile::FillUniformDistribution<BiasDataType>{0.f, 1.f, seed}(bias_host);
|
||||
}
|
||||
else if(init_method == "nf")
|
||||
{
|
||||
ck_tile::FillNormalDistribution<QDataType>{0.f, 3.f, seed}(q_host);
|
||||
@@ -1136,7 +1145,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
<< " GB/s" << std::flush << std::endl;
|
||||
|
||||
if(do_validation == 0)
|
||||
{
|
||||
|
||||
@@ -1315,6 +1315,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -1183,6 +1183,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
|
||||
@@ -712,6 +712,7 @@ struct FmhaFwdDecodeKernel
|
||||
{
|
||||
// reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
|
||||
// hdim_q)
|
||||
// We expect Q data reuse among different KVSplited in decode case.
|
||||
const auto view = make_naive_tensor_view<address_space_enum::global>(
|
||||
q_ptr,
|
||||
make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
|
||||
@@ -755,7 +756,8 @@ struct FmhaFwdDecodeKernel
|
||||
}();
|
||||
|
||||
const auto make_k_dram = [&](const KDataType* data, index_t height) {
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
// We don't expect K data reuse among different blocks in decode case.
|
||||
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(height, kargs.hdim_q),
|
||||
make_tuple(kargs.stride_k, 1),
|
||||
@@ -781,7 +783,8 @@ struct FmhaFwdDecodeKernel
|
||||
const auto make_v_dram = [&](const VDataType* data, index_t length) {
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
|
||||
// We don't expect V data reuse among different blocks in decode case.
|
||||
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global, memory_operation_enum::set, amd_buffer_coherence_enum::SYSTEM_NT1>(
|
||||
data, // will update this pointer if using paged-kvcache
|
||||
make_tuple(length, kargs.hdim_v),
|
||||
make_tuple(kargs.stride_v, 1),
|
||||
|
||||
@@ -44,6 +44,8 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
static constexpr index_t kK1 = BlockFmhaShape::kK1;
|
||||
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
|
||||
static constexpr index_t kSubQKHeaddim = BlockFmhaShape::kSubQKHeaddim;
|
||||
static constexpr index_t kNWarp = BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kNXdl = BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
|
||||
|
||||
static_assert(kSubQKHeaddim <= 256, "hdim bigger than 256 is not suitable for this pipeline!");
|
||||
|
||||
@@ -546,13 +548,21 @@ struct BlockFmhaFwdDecodePipelineQRKSVS
|
||||
}
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// In Nwarp=1 and NXdl=32, GEMM0 output naturally fit the input of GEMM1
|
||||
// Otherwise shuffle through LDS so that the tile layout is consistent with required by Gemm1
|
||||
auto s_new = [&](){
|
||||
if constexpr ( !((kNWarp==1) && (kNXdl == 32)) ){
|
||||
auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
|
||||
// shuffle through LDS so that the tile layout is consistent with required by Gemm1
|
||||
store_tile(s_write_lds_window, s);
|
||||
block_sync_lds();
|
||||
auto s_new = load_tile(s_read_lds_window);
|
||||
store_tile(s_write_lds_window, s);
|
||||
block_sync_lds();
|
||||
return load_tile(s_read_lds_window);
|
||||
}
|
||||
else{
|
||||
return cast_tile<SMPLComputeDataType>(s_acc); // S{j}
|
||||
}
|
||||
}();
|
||||
|
||||
auto m_local = block_tile_reduce<SMPLComputeDataType>(
|
||||
s_new,
|
||||
|
||||
@@ -157,7 +157,7 @@ struct BlockFmhaFwdDecodePipelineQRKSVSDefaultPolicy
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
static_assert(MWarp == 1, "Check failed!");
|
||||
// static_assert(MWarp == 1, "Check failed!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
Reference in New Issue
Block a user