mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 18:17:44 +00:00
Re-enable optimization for gfx950 fmha fwd (#2671)
* Fix for fwd/bwd kernel build filter
* fix bwd code
* save an example for __bf16 type
* temp save, waiting for debug
* tempsave, fmha_decode
* temp save, change all instance to 1wave
* fix async copytest bug
* Add block_sync_lds_direct_load utility
* fix the s_waitcnt_imm calculation
* Improve s_waitcnt_imm calculation
* fix vmcnt shift
* add input validation and bug fix
* remove unnecessary output
* move test_copy into test
* temp save
* tempsave
* compile pass
* tempsave, trload+asyncload done
* tempsave. asynccopy+trload sanity checked
* remove unnecessary features
* fix the lds alignment caused performance regression
* enable prefill overload operator().
* remove all lds bankconflict with xor layouts
* enable larger tile size; upgrade xor pattern
* upgrade prefill pipeline; simple iglp; consistent data produce and consume order
* small refactor
* Load Q through lds, implement xor;
* add vmcnt guard before load ktile
* Add v_permlaneb32 for block_reduce. Disable it as it will cause un-coexecutable packed math in FA
* Add XOR fold strategy for hdim<128, but perf dropped; disable it by default; wait further perf debug
* add __restrict__ to tr load
* merge fa_decode pipeline into fmha_fwd api
* remove unnecessary files; rename some files
* Remove unnecessary changes
* bug fix, clang format;
* remove non-necessary change
* fix clangformat with 18.1.3
* fix bugs
* fix bug
* fix bug on non-gfx950
* fix bugs in gemm
* fix bug in pki4
* tempsave, update the blocksync functions
* change the warp setting for hdim32 fmha fwd
* clang format
* fix conflict. disable all v-col instance for fmha fwd
* Fix the bug
* clang format
* refactor blockgemm change, isolate to v2;
---------
Co-authored-by: Max Podkorytov <4273004+tenpercent@users.noreply.github.com>
Co-authored-by: asleepzzz <hanwen.chang@amd.com>
[ROCm/composable_kernel commit: 05a6e92705]
This commit is contained in:
@@ -115,6 +115,7 @@ PIPELINE_MAP = {
|
||||
"qr" : "ck_tile::BlockFmhaPipelineQRKSVS",
|
||||
"qr_async" : "ck_tile::BlockFmhaPipelineQRKSVSAsync",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineQSKSVS",
|
||||
"qr_async_trload" : "ck_tile::BlockFmhaPipelineQRKSVSAsyncTrload",
|
||||
}
|
||||
|
||||
PIPELINE_ENUM_MAP = {
|
||||
@@ -123,6 +124,7 @@ PIPELINE_ENUM_MAP = {
|
||||
"qr_nwarp_sshuffle" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qs" : "ck_tile::BlockFmhaPipelineEnum::QSKSVS",
|
||||
"qr_pagedkv" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS",
|
||||
"qr_async_trload" : "ck_tile::BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD",
|
||||
}
|
||||
|
||||
BOOL_MAP = {
|
||||
|
||||
@@ -12,6 +12,7 @@ from typing import List, Optional, Tuple
|
||||
|
||||
from codegen.cmake_config import *
|
||||
from codegen.cpp_symbol_map import *
|
||||
from codegen.utils import update_file
|
||||
|
||||
|
||||
DTYPE_BITS = {
|
||||
@@ -83,6 +84,7 @@ using fmha_pipeline_problem_{F_idx} = ck_tile::BlockFmhaPipelineProblem<
|
||||
{F_mode},
|
||||
fmha_variant_{F_idx},
|
||||
fmha_mask_{F_idx},
|
||||
{F_trload},
|
||||
fmha_trait_{F_idx}>;
|
||||
|
||||
using fmha_pipeline_{F_idx} = {F_pipeline}<
|
||||
@@ -97,7 +99,7 @@ using fmha_kernel_{F_idx} =
|
||||
ck_tile::FmhaFwdKernel<fmha_pipeline_{F_idx}, fmha_epilogue_{F_idx}>;
|
||||
|
||||
using trait_{F_idx} = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode},{F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout},
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
{F_pipeline_enum}, {F_logits}, fmha_mask_{F_idx}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
|
||||
#include <iostream>
|
||||
|
||||
@@ -161,12 +163,19 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
|
||||
[[maybe_unused]] auto get_num_blocks = [&](unsigned kM0) {{
|
||||
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
|
||||
}};
|
||||
|
||||
const bool has_load_tr = ck_tile::is_load_tr_supported();
|
||||
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_TRLOAD=""" {F_if}({F_trload_cond}){{
|
||||
{F_dtype_case}
|
||||
}}
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_PER_DTYPE=""" {F_if}(t.data_type.compare(\"{F_dtype}\") == 0){{
|
||||
{F_hdim_case}
|
||||
}}
|
||||
@@ -177,8 +186,8 @@ FMHA_FWD_API_PER_HDIM_CASE=""" {F_if} (t.hdim_q <= {F_hdim} && t.hdim_v <
|
||||
"""
|
||||
|
||||
FMHA_FWD_API_INNER_DISPATCH=""" {F_if}((t.is_group_mode == {F_mode}) && (t.is_v_rowmajor == {F_vlayout}) && (t.has_logits_soft_cap == {F_logits}) && ({F_mask_check}) && (t.bias_type == {F_bias_check}) && (t.has_lse == {F_lse}) && (t.has_dropout == {F_dropout}) && (t.do_fp8_static_quant == {F_squant}) && (t.skip_min_seqlen_q == {F_skip}) &&
|
||||
({F_scheck}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_skip}>;
|
||||
({F_scheck}) && ({F_seqtune}) && ({F_skcheck}) && ({F_dcheck}) && ({F_dvcheck}) && ({F_constraint})) {{
|
||||
using trait_ = fmha_fwd_traits_<{F_hdim}, {F_dtype}, {F_mode}, {F_bm0}, {F_bn0}, {F_bk0}, {F_bn1}, {F_bk1}, {F_bk0max}, {F_vlayout}, {F_pipeline_enum}, {F_logits}, {F_mask}, {F_bias}, {F_lse}, {F_dropout}, {F_squant}, {F_spad}, {F_skpad}, {F_dpad}, {F_dvpad}, {F_trload}, {F_skip}>;
|
||||
return fmha_fwd_<trait_>(s, a);
|
||||
}}
|
||||
"""
|
||||
@@ -221,6 +230,7 @@ class FmhaFwdApiTrait:
|
||||
dpad : str
|
||||
dvpad : str
|
||||
skip : str
|
||||
tr_load : str
|
||||
constraint : CppConstraint
|
||||
|
||||
@property
|
||||
@@ -231,13 +241,19 @@ class FmhaFwdApiTrait:
|
||||
@property
|
||||
def scheck(self) -> str:
|
||||
if self.mode == 'group': return 'true/*group mode spad always true*/' # group mode only generate spad/skpad == true
|
||||
if self.pipeline_tag == 'qr_async':
|
||||
if self.pipeline_tag in ['qr_async', 'qr_async_trload']:
|
||||
if self.spad == 't' : return 'true' # always support
|
||||
else : return 'true'
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.spad == 't' : return f'true /*a.seqlen_q % {self.bm0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_q % {self.bm0} == 0'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
def seqtune(self) -> str:
|
||||
if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true
|
||||
else:
|
||||
return f'a.seqlen_q <= {self.bm0}'
|
||||
|
||||
@property
|
||||
def skcheck(self) -> str:
|
||||
@@ -248,6 +264,9 @@ class FmhaFwdApiTrait:
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
if self.skpad == 't' : return f'true /*a.seqlen_k % {self.bn0} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.seqlen_k % {self.bn0} == 0'
|
||||
elif self.pipeline_tag == 'qr_async_trload':
|
||||
if self.skpad == 't' : return 'true'
|
||||
else: return 'true'
|
||||
else: assert False
|
||||
|
||||
@property
|
||||
@@ -256,7 +275,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dpad == 't': return f'a.hdim_q % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dpad == 't': return f'true /*a.hdim_q % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_q % {bk0submax} == 0'
|
||||
@@ -268,7 +287,7 @@ class FmhaFwdApiTrait:
|
||||
vec = int((32 * 4) / DTYPE_BITS[self.dtype])
|
||||
if self.dvpad == 't': return f'a.hdim_v % {vec} == 0'
|
||||
else : assert False
|
||||
elif self.pipeline_tag in ['qr', 'qs']:
|
||||
elif self.pipeline_tag in ['qr', 'qs', 'qr_async_trload']:
|
||||
bk0submax = K0_MAX_SUBMAX_MAP[self.bk0max]
|
||||
if self.dvpad == 't': return f'true /*a.hdim_v % {bk0submax} != 0*/' # TODO: order of get_pipelines() matters! (ugly)
|
||||
else : return f'a.hdim_v % {bk0submax} == 0'
|
||||
@@ -290,6 +309,7 @@ class FmhaFwdPipeline:
|
||||
F_squant : str #
|
||||
F_mask : str # value from MASK_MAP
|
||||
F_skip : str # true/false
|
||||
F_trload : str # true/false
|
||||
F_constraint : CppConstraint = field(default_factory=lambda: CppConstraint())
|
||||
|
||||
@property
|
||||
@@ -331,6 +351,9 @@ class FmhaFwdPipeline:
|
||||
|
||||
if self.F_squant == 't' : n += '_squant'
|
||||
else: n += '_nsquant'
|
||||
|
||||
if self.F_trload == 't' : n += '_trload'
|
||||
else: n += '_ntrload'
|
||||
|
||||
return n
|
||||
|
||||
@@ -351,31 +374,39 @@ class FmhaFwdApiPool:
|
||||
|
||||
@property
|
||||
def api(self) -> str:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][(hdim, hdim_v)]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
if not per_dtypes:
|
||||
tr_load_cond_map = {
|
||||
"t": "has_load_tr",
|
||||
"f": "true"
|
||||
}
|
||||
|
||||
per_tr_load =str()
|
||||
for tr_load in ["t", "f"]:
|
||||
per_dtypes=str()
|
||||
for i, dtype in enumerate(self.pool.keys()):
|
||||
per_hdim_case=str()
|
||||
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
|
||||
traits=self.pool[dtype][(hdim, hdim_v)]
|
||||
inners=str()
|
||||
for k, trait in enumerate(traits):
|
||||
if_k = 'if' if k == 0 else 'else if'
|
||||
inners = inners + FMHA_FWD_API_INNER_DISPATCH.format(F_if=if_k, F_mode=MODE_MAP[trait.mode], F_vlayout=LAYOUT_MAP[trait.vlayout],
|
||||
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
|
||||
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
|
||||
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load],
|
||||
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
|
||||
F_constraint=trait.constraint,
|
||||
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
|
||||
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
|
||||
F_hdim=hdim, F_dtype=FWD_DTYPE_MAP[dtype])
|
||||
if_j = 'if' if j == 0 else 'else if'
|
||||
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim_v, F_inner_dispatch=inners)
|
||||
if_i = 'if' if i == 0 else 'else if'
|
||||
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
|
||||
per_tr_load += FMHA_FWD_API_PER_TRLOAD.format(F_if='if', F_trload_cond=tr_load_cond_map[tr_load], F_dtype_case=per_dtypes)
|
||||
if not per_tr_load:
|
||||
# empty string we add some ignore to suppress warning in api
|
||||
per_dtypes += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_dtypes)
|
||||
per_tr_load += ' (void)t ; (void)s ; (void)a;'
|
||||
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch = per_tr_load)
|
||||
|
||||
@dataclass
|
||||
class FmhaFwdTileSize:
|
||||
@@ -458,7 +489,8 @@ class FmhaFwdKernel:
|
||||
F_pipeline_enum = PIPELINE_ENUM_MAP[self.F_pipeline.tag],
|
||||
F_mask = get_mask_map(self.mask_impl)[self.F_pipeline.F_mask],
|
||||
F_mode = MODE_MAP[self.F_mode],
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag])
|
||||
F_pipeline = PIPELINE_MAP[self.F_pipeline.tag],
|
||||
F_trload = BOOL_MAP[self.F_pipeline.F_trload])
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
@@ -494,6 +526,7 @@ class FmhaFwdKernel:
|
||||
dpad=self.F_pipeline.F_dpad,
|
||||
dvpad=self.F_pipeline.F_dvpad,
|
||||
skip=self.F_pipeline.F_skip,
|
||||
tr_load=self.F_pipeline.F_trload,
|
||||
constraint=self.F_tile.F_constraint & self.F_pipeline.F_constraint)
|
||||
|
||||
class KernelComponentFactory:
|
||||
@@ -503,10 +536,15 @@ class KernelComponentFactory:
|
||||
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
|
||||
if dtype == 'fp16' or dtype == 'bf16':
|
||||
return {
|
||||
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 2, 1, 1, 2, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(64, 64) : [FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize(32, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(96, 128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(128,128) : [FmhaFwdTileSize(16, 32, 64, 128, 32, 128, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
|
||||
FmhaFwdTileSize(32, 32, 128, 128, 32, 128, 1, 1, 1, 1, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 64, 32, 128, 16, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1),
|
||||
FmhaFwdTileSize(128, 128, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
# (160,160) : [FmhaFwdTileSize(128, 128, 32, 160, 32, 160, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
(192,128) : [FmhaFwdTileSize(128, 128, 32, 128, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
|
||||
(192,192) : [FmhaFwdTileSize(128, 128, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, 1)],
|
||||
@@ -534,34 +572,27 @@ class KernelComponentFactory:
|
||||
if dtype in ['fp16', 'bf16']:
|
||||
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
|
||||
if hdim == 256 and hdim_v == 256:
|
||||
# if True:
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
# the below two is used for hdim vectorize load
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
else:
|
||||
if bias == "bias":
|
||||
# TODO: rocm 6.2 compiler problem if using qr_async for bias case
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
else:
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
|
||||
if (hdim, hdim_v) in [(64, 64), (128, 128)] and logits == "f" and bias == "no" and dropout == "f" and lse == "f" and skip == "f":
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
pipelines.append(FmhaFwdPipeline('qr_async_trload', 'row', 'f', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 't'))
|
||||
if receipt == 1 and bias != "bias":
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', logits, bias, lse, dropout, squant, mask, skip)) # TODO: cover arbitraty hdim
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f')) # TODO: cover arbitraty hdim
|
||||
elif dtype in ['fp8', 'bf8']:
|
||||
# no need lse/dropout kernels
|
||||
for logits, mask, bias in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys()):
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f'))
|
||||
pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', logits, bias, 'f', 'f', squant, mask, 'f', 'f'))
|
||||
elif dtype in ['fp8fp16', 'fp8bf16']:
|
||||
# TODO
|
||||
None
|
||||
@@ -599,6 +630,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
|
||||
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
|
||||
continue
|
||||
if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)):
|
||||
# non qr_async_trload only support km0=128 tile size when hdim is not 128
|
||||
# non qr_async only support kn0=128 tile size when hdim is 128
|
||||
continue
|
||||
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
|
||||
continue
|
||||
# logits_soft_cap is only allowed if no bias
|
||||
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
|
||||
continue
|
||||
@@ -665,10 +702,10 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
|
||||
return (api_pool, gen)
|
||||
|
||||
def write_single_fwd_kernel(kernel: FmhaFwdKernel, autogen_dir: Path) -> None:
|
||||
(autogen_dir / kernel.filename).write_text(kernel.template)
|
||||
update_file(autogen_dir / kernel.filename, kernel.template)
|
||||
|
||||
def write_fwd_api(api_pool : FmhaFwdApiPool, autogen_dir: Path) -> None:
|
||||
(autogen_dir / FMHA_FWD_API_FILENAME).write_text(api_pool.api)
|
||||
update_file(autogen_dir / FMHA_FWD_API_FILENAME, api_pool.api)
|
||||
|
||||
def write_blobs(output_dir : Path, kernel_filter : str, receipt, optdim_list, mask_impl) -> None:
|
||||
api_pool, kernels = get_fwd_blobs(kernel_filter, receipt, optdim_list, mask_impl)
|
||||
|
||||
@@ -1135,7 +1135,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
|
||||
|
||||
std::cout << std::fixed << ", " << std::setprecision(3) << ave_time << " ms, "
|
||||
<< std::setprecision(2) << tflops << " TFlops, " << std::setprecision(2) << gb_per_sec
|
||||
<< " GB/s" << std::flush;
|
||||
<< " GB/s" << std::flush << std::endl;
|
||||
|
||||
if(do_validation == 0)
|
||||
{
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/fmha.hpp"
|
||||
@@ -1028,6 +1029,7 @@ template <ck_tile::index_t HDim_,
|
||||
bool kPadSK_,
|
||||
bool kPadD_,
|
||||
bool kPadDv_,
|
||||
bool kUseTrLoad_,
|
||||
bool kSkipMinSeqlenQ_ = false>
|
||||
struct fmha_fwd_traits_
|
||||
{
|
||||
@@ -1052,6 +1054,7 @@ struct fmha_fwd_traits_
|
||||
static constexpr bool kPadSK = kPadSK_;
|
||||
static constexpr bool kPadD = kPadD_;
|
||||
static constexpr bool kPadDv = kPadDv_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
static constexpr bool kSkipMinSeqlenQ = kSkipMinSeqlenQ_;
|
||||
};
|
||||
|
||||
|
||||
@@ -18,14 +18,3 @@ $EXE -prec=$prec -b=1 -h=$nhead -d=$hdim -s=16384 -iperm=$perm -operm=$perm -kn
|
||||
done
|
||||
done
|
||||
done
|
||||
|
||||
for perm in 0 1 ; do
|
||||
|
||||
$EXE -prec=fp8 -squant=1 -b=32 -h=16 -d=128 -s=512 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=fp8 -squant=1 -b=16 -h=16 -d=128 -s=1024 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=fp8 -squant=1 -b=8 -h=16 -d=128 -s=2048 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=fp8 -squant=1 -b=4 -h=16 -d=128 -s=4096 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=fp8 -squant=1 -b=2 -h=16 -d=128 -s=8192 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
|
||||
$EXE -prec=fp8 -squant=1 -b=1 -h=16 -d=128 -s=16384 -iperm=$perm -operm=$perm -vlayout=c -range_q=240 -range_k=240 -range_v=240 -range_p=240 -range_o=240 -kname=1 -v=$VALID ; sleep 3
|
||||
|
||||
done
|
||||
@@ -42,7 +42,6 @@ run_fp16_bf16_tests() {
|
||||
for prec in "fp16" "bf16" ; do
|
||||
for mode in 1 0 ; do
|
||||
for perm in 0 1 ; do
|
||||
for vlayout in "r" "c" ; do
|
||||
for hdim in 32 64 128 256 ; do
|
||||
for lse in 0 1 ; do
|
||||
for bias in "n" "e" "a" ; do
|
||||
@@ -51,16 +50,16 @@ run_fp16_bf16_tests() {
|
||||
for page_block_size in $PAGE_BLOCK_SIZE ; do
|
||||
for cache_batch_idx in $CACHE_BATCH_IDX ; do
|
||||
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -vlayout=$vlayout -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
# $EXE -prec=$prec -mode=$mode -b=1 -h=1 -d=$hdim -s=1024 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=2 -h_k=1 -d=16, -d_v=$hdim -s=55 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=3 -d=$hdim -s=100 -s_k=51 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=16 -d_v=$hdim -s=99 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=1 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1024 -s_k=256 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -d_v=24 -s=3 -s_k=99 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=3 -h=2 -h_k=1 -d=$hdim -s=200 -s_k=520 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=t:128,30 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=2 -h=1 -d=$hdim -s=99 -s_k=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=b:4,35 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=33 -s_k=0 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
$EXE -prec=$prec -mode=$mode -b=1 -h=2 -h_k=1 -d=$hdim -s=1 -s_k=10 -s_kpad=32 -bias=$bias -p_drop=$p_drop -lse=$lse -iperm=$perm -operm=$perm -mask=2 -num_splits=$num_splits -page_block_size=$page_block_size -cache_batch_idx=$cache_batch_idx -kname=$KNAME $COMMON_ARGS
|
||||
|
||||
done ; done ; done ; done ; done
|
||||
done ; done ; done ; done ; done
|
||||
|
||||
@@ -41,10 +41,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -1318,6 +1314,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
@@ -2756,7 +2763,7 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
|
||||
@@ -32,10 +32,6 @@ CK_TILE_DEVICE int32x4_t make_wave_buffer_resource(const void* ptr, uint32_t siz
|
||||
{
|
||||
buffer_resource res{ptr, size, CK_TILE_BUFFER_RESOURCE_3RD_DWORD};
|
||||
int32x4_t r = __builtin_bit_cast(int32x4_t, res);
|
||||
r.x = __builtin_amdgcn_readfirstlane(r.x);
|
||||
r.y = __builtin_amdgcn_readfirstlane(r.y);
|
||||
r.z = __builtin_amdgcn_readfirstlane(r.z);
|
||||
r.w = __builtin_amdgcn_readfirstlane(r.w);
|
||||
return r;
|
||||
}
|
||||
|
||||
@@ -1186,6 +1182,17 @@ enum struct amd_buffer_coherence_enum
|
||||
glc = 1,
|
||||
slc = 2,
|
||||
glc_slc = 3,
|
||||
// gfx94: bit 0 = sc0, bit 1 = nt, bit 3 = swz, bit 4 = sc1
|
||||
// SC[1:0] System Cache level: 0=wave, 1=group, 2=device, 3=system
|
||||
// NT Non-Temporal: 0=expect temporal reuse; 1=do not expect temporal reuse
|
||||
WAVE_NT0 = 0,
|
||||
WAVE_NT1 = 2,
|
||||
GROUP_NT0 = 1,
|
||||
GROUP_NT1 = 3,
|
||||
DEVICE_NT0 = 8,
|
||||
DEVICE_NT1 = 10,
|
||||
SYSTEM_NT0 = 9,
|
||||
SYSTEM_NT1 = 11,
|
||||
};
|
||||
|
||||
template <index_t N,
|
||||
@@ -2574,7 +2581,7 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
|
||||
|
||||
#if defined(__gfx950__)
|
||||
template <typename T, index_t N, address_space_enum BufferAddressSpace>
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* in_ptr)
|
||||
__device__ auto amd_transpose_load_to_vgpr(const T* __restrict__ in_ptr)
|
||||
{
|
||||
|
||||
static_assert(__has_builtin(__builtin_amdgcn_raw_buffer_load_b32),
|
||||
|
||||
@@ -89,21 +89,6 @@ CK_TILE_DEVICE index_t get_thread_id() { return threadIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE index_t get_block_id() { return blockIdx.x; }
|
||||
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
#if CK_TILE_EXPERIMENTAL_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
|
||||
// asm volatile("\
|
||||
// s_waitcnt lgkmcnt(0) \n \
|
||||
// s_barrier \
|
||||
// " ::);
|
||||
|
||||
__builtin_amdgcn_s_waitcnt(0xc07f);
|
||||
__builtin_amdgcn_s_barrier();
|
||||
#else
|
||||
__syncthreads();
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void block_sync_load_raw(index_t cnt = 0)
|
||||
{
|
||||
#ifdef __gfx12__
|
||||
@@ -174,6 +159,18 @@ CK_TILE_DEVICE void s_waitcnt_barrier()
|
||||
__builtin_amdgcn_s_barrier();
|
||||
}
|
||||
|
||||
template <index_t lgkmcnt = 0>
|
||||
CK_TILE_DEVICE void block_sync_lds()
|
||||
{
|
||||
s_waitcnt_barrier<waitcnt_arg::kMaxVmCnt, waitcnt_arg::kMaxExpCnt, lgkmcnt>();
|
||||
}
|
||||
|
||||
template <index_t vmcnt = 0>
|
||||
CK_TILE_DEVICE void block_sync_lds_direct_load()
|
||||
{
|
||||
s_waitcnt_barrier<vmcnt, waitcnt_arg::kMaxExpCnt, waitcnt_arg::kMaxLgkmCnt>();
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void s_nop(index_t cnt = 0)
|
||||
{
|
||||
#if 1
|
||||
|
||||
@@ -59,6 +59,21 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE auto warp_shuffle_down_pair(const T& v_local)
|
||||
{
|
||||
static_assert(sizeof(T) == sizeof(int32_t), "wrong!");
|
||||
|
||||
const int32x2_t x = __builtin_amdgcn_permlane32_swap(
|
||||
bit_cast<int32_t>(v_local), bit_cast<int32_t>(v_local), false, false);
|
||||
|
||||
thread_buffer<T, 2> v;
|
||||
v(0) = bit_cast<T>(x[0]);
|
||||
v(1) = bit_cast<T>(x[1]);
|
||||
|
||||
return v;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
|
||||
{
|
||||
|
||||
@@ -191,6 +191,16 @@
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// use llvm builtin bf16 data type after ROCm 6.5
|
||||
#ifndef CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 5 && HIP_VERSION_PATCH >= 50421) || \
|
||||
(HIP_VERSION_MAJOR >= 7)
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 1
|
||||
#else
|
||||
#define CK_TILE_USE_LLVM_BUILTIN_BF16 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifndef CK_TILE_DEBUG_LOG
|
||||
#define CK_TILE_DEBUG_LOG 0
|
||||
#endif
|
||||
|
||||
@@ -6,6 +6,9 @@
|
||||
#include "ck_tile/core/numeric/half.hpp"
|
||||
#include "ck_tile/core/numeric/integral_constant.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
#include <hip/hip_bfloat16.h>
|
||||
#endif
|
||||
#include <stdint.h>
|
||||
|
||||
#pragma once
|
||||
@@ -102,7 +105,11 @@ struct native_t<bfloat16_t>
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = typename bf16_t::raw_type;
|
||||
#else
|
||||
#if CK_TILE_USE_LLVM_BUILTIN_BF16
|
||||
using bfloat16_t = __bf16;
|
||||
#else
|
||||
using bfloat16_t = ushort;
|
||||
#endif
|
||||
using bf16_t = bfloat16_t;
|
||||
using bf16_raw_t = uint16_t;
|
||||
#endif
|
||||
@@ -280,7 +287,11 @@ template <bf16_rounding_mode rounding =
|
||||
static_cast<bf16_rounding_mode>(CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT)>
|
||||
CK_TILE_HOST_DEVICE constexpr bfloat16_t float_to_bf16(float f, constant<rounding> = {})
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return static_cast<bfloat16_t>(f);
|
||||
#else
|
||||
return bit_cast<bfloat16_t>(float_to_bf16_raw(f, constant<rounding>{}));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <bf16_rounding_mode rounding =
|
||||
|
||||
@@ -21,7 +21,7 @@ namespace ck_tile {
|
||||
using fp32_t = float;
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr uint8_t float_to_e2m1(float x, float scale = 1.f);
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ struct numeric_traits<pk_int4_t>
|
||||
|
||||
using fp32x2_t = float __attribute__((ext_vector_type(2)));
|
||||
using fp16x2_t = _Float16 __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
|
||||
CK_TILE_HOST_DEVICE fp32x2_t pk_int4_t_to_fp32x2_t(const pk_int4_t& x)
|
||||
{
|
||||
|
||||
@@ -131,12 +131,12 @@ using fp16x64_t = _Float16 __attribute__((ext_vector_type(64)));
|
||||
|
||||
// bf16
|
||||
// using bf16_t = ...
|
||||
using bf16x2_t = bf16_raw_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bf16_raw_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bf16_raw_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bf16_raw_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bf16_raw_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bf16_raw_t __attribute__((ext_vector_type(64)));
|
||||
using bf16x2_t = bfloat16_t __attribute__((ext_vector_type(2)));
|
||||
using bf16x4_t = bfloat16_t __attribute__((ext_vector_type(4)));
|
||||
using bf16x8_t = bfloat16_t __attribute__((ext_vector_type(8)));
|
||||
using bf16x16_t = bfloat16_t __attribute__((ext_vector_type(16)));
|
||||
using bf16x32_t = bfloat16_t __attribute__((ext_vector_type(32)));
|
||||
using bf16x64_t = bfloat16_t __attribute__((ext_vector_type(64)));
|
||||
|
||||
// i32
|
||||
// using int32_t = ...
|
||||
|
||||
@@ -52,6 +52,8 @@
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_trload_policy.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -11,6 +11,7 @@ enum class BlockFmhaPipelineEnum
|
||||
QRKSVS = 0,
|
||||
QRKSVS_ASYNC,
|
||||
QSKSVS,
|
||||
QRKSVS_ASYNC_TRLOAD,
|
||||
};
|
||||
|
||||
template <BlockFmhaPipelineEnum>
|
||||
@@ -32,4 +33,10 @@ struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QSKSVS>
|
||||
static constexpr const char* name = "qs";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct BlockFmhaPipelineEnumToStr<BlockFmhaPipelineEnum::QRKSVS_ASYNC_TRLOAD>
|
||||
{
|
||||
static constexpr const char* name = "qr_async_trload";
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -22,6 +22,7 @@ template <typename QDataType_,
|
||||
bool kIsGroupMode_,
|
||||
typename AttentionVariant_,
|
||||
typename FmhaMask_,
|
||||
bool kUseTrLoad_,
|
||||
typename Traits_>
|
||||
struct BlockFmhaPipelineProblem
|
||||
{
|
||||
@@ -46,6 +47,7 @@ struct BlockFmhaPipelineProblem
|
||||
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kIsGroupMode = kIsGroupMode_;
|
||||
static constexpr bool kUseTrLoad = kUseTrLoad_;
|
||||
|
||||
// attributes from traits
|
||||
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,821 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp"
|
||||
|
||||
// can remove all bank conflicts, but drop the performance for some cases
|
||||
// Probably it is limited by compiler optimization.
|
||||
#define CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD 0
|
||||
namespace ck_tile {
|
||||
// This pipeline is qkv all located in LDS
|
||||
struct BlockFmhaPipelineQRKSVSAsyncTrloadDefaultPolicy
|
||||
: BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
/* NumPrefetchV = */ 1>
|
||||
{
|
||||
using BasePolicy = BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
|
||||
/* AsyncCopy = */ false,
|
||||
/* NumPrefetchK = */ 1,
|
||||
/* NumPrefetchV = */ 1>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
|
||||
// this should align with MakeQDramTileDistribution()
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOacc()
|
||||
{
|
||||
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
|
||||
|
||||
return static_cast<index_t>(16 / sizeof(OaccDataType));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::KDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
return min(ElemPerThread, MaxVectorSize);
|
||||
}
|
||||
|
||||
template <typename Problem, bool BypassLDS = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
|
||||
{
|
||||
if constexpr(!BypassLDS)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::QDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kMPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t KPerThread = kMaxVecLoad;
|
||||
constexpr index_t KThreads = kKPerBlock / KPerThread;
|
||||
constexpr index_t MThreadPerWarp = get_warp_size() / KThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t MPerThread = kMPerBlock / (MThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
|
||||
sequence<KThreads, KPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
constexpr auto q_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
|
||||
|
||||
return q_block_dstr;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, bool LoadOnce = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
|
||||
{
|
||||
using KDataType = remove_cvref_t<typename Problem::KDataType>;
|
||||
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock =
|
||||
LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(KDataType);
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
|
||||
constexpr index_t K1 = min(MaxVectorSize, ElemPerThread);
|
||||
constexpr index_t K0 = kKPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N0 = kNPerBlock / (N2 * N1);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read M first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto q_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
|
||||
|
||||
return q_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
using QDataType = remove_cvref_t<typename Problem::QDataType>;
|
||||
return static_cast<index_t>(16 / sizeof(QDataType));
|
||||
}
|
||||
|
||||
template <typename Problem, bool Xor = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kSubQKHeaddim;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
|
||||
|
||||
constexpr auto q_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::QDataType);
|
||||
constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
|
||||
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<LDSLayerSize>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_naive,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kMPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto q_lds_block_desc_tmp = transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kMPerBlock / XorLengthFold>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
q_lds_block_desc_tmp,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kMPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
constexpr auto q_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
number<kMPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto q_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
q_lds_block_desc_naive,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kMPerBlock>{},
|
||||
number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
q_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return q_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem, bool LoadOnce = false, bool Xor = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock =
|
||||
LoadOnce ? Problem::BlockFmhaShape::kSubQKHeaddim : Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackK<Problem>();
|
||||
|
||||
constexpr auto k_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::KDataType);
|
||||
constexpr auto XorLengthFold = LDSLayerSize / kKPerBlock;
|
||||
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{},
|
||||
number<kKPack>{}),
|
||||
make_tuple(number<LDSLayerSize>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
k_lds_block_desc_naive,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kNPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto k_lds_block_desc_tmp = transform_tensor_descriptor(
|
||||
k_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kNPerBlock / XorLengthFold>{}),
|
||||
make_unmerge_transform(
|
||||
make_tuple(number<XorLengthFold>{}, number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
k_lds_block_desc_tmp,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<kNPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
constexpr auto k_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(
|
||||
number<kNPerBlock>{}, number<kKPerBlock / kKPack>{}, number<kKPack>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<kKPack>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto k_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
k_lds_block_desc_naive,
|
||||
make_tuple(make_xor_transform(make_tuple(number<kNPerBlock>{},
|
||||
number<kKPerBlock / kKPack>{})),
|
||||
make_pass_through_transform(number<kKPack>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
k_lds_block_desc_permuted,
|
||||
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kKPerBlock / kKPack>{}, number<kKPack>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
make_tuple(number<kKPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return k_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem, bool Xor = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t kKPack = GetSmemKPackV<Problem>();
|
||||
|
||||
constexpr auto v_lds_block_desc = [&]() {
|
||||
if constexpr(Xor)
|
||||
{
|
||||
constexpr auto XorGroupSize =
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{});
|
||||
|
||||
#if CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
constexpr auto LDSLayerSize = 256 / sizeof(typename Problem::VDataType);
|
||||
constexpr auto XorLengthFold = LDSLayerSize / kNPerBlock;
|
||||
|
||||
if constexpr(XorLengthFold > 1)
|
||||
{
|
||||
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / XorGroupSize>{},
|
||||
number<XorGroupSize>{}),
|
||||
make_tuple(number<LDSLayerSize>{}, number<XorGroupSize>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<kKPerBlock / XorLengthFold>{},
|
||||
number<LDSLayerSize / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto v_lds_block_desc_tmp = transform_tensor_descriptor(
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kKPerBlock / XorLengthFold>{}),
|
||||
make_unmerge_transform(make_tuple(number<XorLengthFold>{},
|
||||
number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
v_lds_block_desc_tmp,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kKPerBlock / XorLengthFold>{}, number<XorLengthFold>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
else
|
||||
#endif // CK_TILE_FMHA_HANDLE_XOR_LENGTH_FOLD
|
||||
{
|
||||
constexpr auto v_lds_block_desc_naive = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock>{},
|
||||
number<kNPerBlock / XorGroupSize>{},
|
||||
number<XorGroupSize>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<XorGroupSize>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto v_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
v_lds_block_desc_naive,
|
||||
make_tuple(make_xor_transform(make_tuple(
|
||||
number<kKPerBlock>{}, number<kNPerBlock / XorGroupSize>{})),
|
||||
make_pass_through_transform(number<XorGroupSize>{})),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 1>{}, sequence<2>{}));
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
v_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kKPerBlock>{}),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<kNPerBlock / XorGroupSize>{}, number<XorGroupSize>{}))),
|
||||
make_tuple(sequence<0>{}, sequence<1, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{}),
|
||||
make_tuple(number<kNPerBlock>{}, number<1>{}),
|
||||
number<kKPack>{},
|
||||
number<1>{});
|
||||
}
|
||||
}();
|
||||
|
||||
return v_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN0,
|
||||
Problem::BlockFmhaShape::kK0>,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
|
||||
|
||||
using WarpGemm =
|
||||
WarpGemmMfmaDispatcher<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
|
||||
true>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::QDataType,
|
||||
typename Problem::KDataType,
|
||||
typename Problem::SaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
|
||||
WarpGemm,
|
||||
GemmLoopOrder::MNK>;
|
||||
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetPVBlockGemm()
|
||||
{
|
||||
using GemmProblem =
|
||||
BlockGemmProblem<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::kBlockSize,
|
||||
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
|
||||
Problem::BlockFmhaShape::kN1,
|
||||
Problem::BlockFmhaShape::kK1>,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
|
||||
|
||||
using WarpGemm = WarpGemmMfmaDispatcher<
|
||||
typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
((Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 16 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 32) ||
|
||||
(Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}) == 32 &&
|
||||
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}) == 16))
|
||||
? WGAttrNumAccessEnum::Double
|
||||
: WGAttrNumAccessEnum::Single>;
|
||||
|
||||
using BlockGemmPolicy =
|
||||
BlockGemmARegBRegCRegV2CustomPolicy<typename Problem::PDataType,
|
||||
typename Problem::VDataType,
|
||||
typename Problem::OaccDataType,
|
||||
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
|
||||
WarpGemm,
|
||||
GemmLoopOrder::KMN>;
|
||||
|
||||
return BlockGemmARegBRegCRegV2<GemmProblem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read N first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto k_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
|
||||
|
||||
return k_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeVDramTileDistribution()
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t MaxVectorSize = 16 / sizeof(typename Problem::VDataType);
|
||||
|
||||
constexpr index_t ElemPerThread = (kNPerBlock * kKPerBlock) / kBlockSize;
|
||||
static_assert(0 < ElemPerThread);
|
||||
constexpr index_t kMaxVecLoad = min(ElemPerThread, MaxVectorSize);
|
||||
|
||||
constexpr index_t NPerThread = kMaxVecLoad;
|
||||
constexpr index_t NThreads = kNPerBlock / NPerThread;
|
||||
constexpr index_t KThreadPerWarp = get_warp_size() / NThreads;
|
||||
constexpr index_t NumWarps = kBlockSize / get_warp_size();
|
||||
constexpr index_t KPerThread = kKPerBlock / (KThreadPerWarp * NumWarps);
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<KPerThread, NumWarps, KThreadPerWarp>,
|
||||
sequence<NThreads, NPerThread>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read M first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto p_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto p_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
p_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
constexpr auto p_block_dstr = make_static_tile_distribution(p_block_dstr_encode);
|
||||
|
||||
return p_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetPVBlockGemm<Problem>())>;
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
|
||||
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
|
||||
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
|
||||
|
||||
// Read N first, then K
|
||||
// This is the same data consume order as BlockGEMM
|
||||
constexpr auto v_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
constexpr auto v_block_dstr =
|
||||
make_static_tile_distribution(typename InputTileDistributionTraits<
|
||||
decltype(v_block_dstr_encode),
|
||||
typename Problem::VDataType>::TransposedDstrEncode{});
|
||||
|
||||
return v_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemNPackS()
|
||||
{
|
||||
using SDataType = remove_cvref_t<typename Problem::SaccDataType>;
|
||||
return static_cast<index_t>(16 / sizeof(SDataType));
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeSLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
|
||||
constexpr index_t kNPack = GetSmemNPackS<Problem>();
|
||||
|
||||
constexpr auto s_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kNPerBlock / kNPack>{}, number<kMPerBlock>{}, number<kNPack>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * kNPack>{}, number<kNPack>{}, number<1>{}),
|
||||
number<kNPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto s_lds_block_desc = transform_tensor_descriptor(
|
||||
s_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<kMPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<kNPerBlock / kNPack>{}, number<kNPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return s_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeSRegTileDistribution()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetKVBlockGemm<Problem>())>;
|
||||
|
||||
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WG = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
constexpr index_t MWarp = config.template at<1>();
|
||||
constexpr index_t NWarp = config.template at<2>();
|
||||
|
||||
// static_assert(MWarp == 1, "Check failed!");
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t kTileK = Problem::BlockFmhaShape::kN0;
|
||||
|
||||
// K2 is equal to Impl::kABKPerLane * kKIterPerWarpGemm
|
||||
constexpr index_t K3 = WG::kK / WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t K2 = WG::WarpGemmAttribute::Impl::kABKLane;
|
||||
constexpr index_t K1 = kKPerBlock / (K2 * K3);
|
||||
constexpr index_t K0 = kTileK / kKPerBlock;
|
||||
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kAMLane;
|
||||
constexpr index_t M1 = MWarp;
|
||||
constexpr index_t M0 = kMPerBlock / (M2 * M1);
|
||||
|
||||
constexpr auto s2_block_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 1>>,
|
||||
tuple<sequence<1, 0>, sequence<2, 2>>,
|
||||
sequence<1, 2, 2, 2>,
|
||||
sequence<0, 0, 1, 3>>{};
|
||||
|
||||
constexpr auto s2_block_dstr = make_static_tile_distribution(s2_block_dstr_encoding);
|
||||
|
||||
return s2_block_dstr;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ()
|
||||
{
|
||||
return MakeQLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::QDataType);
|
||||
}
|
||||
|
||||
template <typename Problem, bool LoadOnce = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK()
|
||||
{
|
||||
return MakeKLdsBlockDescriptor<Problem, LoadOnce>().get_element_space_size() *
|
||||
sizeof(typename Problem::KDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV()
|
||||
{
|
||||
return MakeVLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::VDataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeS()
|
||||
{
|
||||
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
|
||||
|
||||
return NWarp > 1 ? MakeSLdsBlockDescriptor<Problem>().get_element_space_size() *
|
||||
sizeof(typename Problem::SaccDataType)
|
||||
: 0;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
|
||||
{
|
||||
// Alignment on gfx950 is 1280 Bytes
|
||||
// Alignment before gfx950 is 512 Bytes.
|
||||
return max(GetSmemSizeQ<Problem>(),
|
||||
GetSmemSizeK<Problem>() + GetSmemSizeS<Problem>() + GetSmemSizeV<Problem>());
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -383,23 +383,31 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
|
||||
{
|
||||
// TODO: this is for 3d layout
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
return 16 / sizeof(VDataType);
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
|
||||
return kMaxVecLoad;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
|
||||
{
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
using VLayout = remove_cvref_t<typename Problem::BlockFmhaShape::VLayout>;
|
||||
using VDataType = remove_cvref_t<typename Problem::VDataType>;
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
|
||||
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN1;
|
||||
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
constexpr index_t kMaxVecLoad =
|
||||
min(total_pixels, static_cast<index_t>(16 / sizeof(VDataType)));
|
||||
constexpr index_t kMinVecLoad = 4 / sizeof(VDataType);
|
||||
|
||||
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
|
||||
@@ -410,7 +418,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
|
||||
}
|
||||
else
|
||||
{
|
||||
return 16 / sizeof(VDataType);
|
||||
return kMaxVecLoad;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
372
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
372
include/ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2.hpp
Normal file
@@ -0,0 +1,372 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v2_custom_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// This BlockGemm enhanced the control over inst issue order
|
||||
// A is block distributed tensor
|
||||
// B is block distributed tensor
|
||||
// C is block distributed tensor
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct BlockGemmARegBRegCRegV2
|
||||
{
|
||||
private:
|
||||
template <typename PipelineProblem_, typename GemmPolicy_>
|
||||
struct GemmTraits_
|
||||
{
|
||||
using Problem = remove_cvref_t<PipelineProblem_>;
|
||||
using Policy = remove_cvref_t<GemmPolicy_>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
|
||||
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
|
||||
|
||||
static constexpr index_t MWarp = config.template at<1>();
|
||||
static constexpr index_t NWarp = config.template at<2>();
|
||||
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
|
||||
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
|
||||
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
|
||||
|
||||
static constexpr auto BlockGemmLoopOrder = Policy::BlockGemmLoopOrder;
|
||||
|
||||
static constexpr index_t KPack = WarpGemm::kKPerThread;
|
||||
};
|
||||
|
||||
public:
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using Traits = GemmTraits_<Problem, Policy>;
|
||||
|
||||
using WarpGemm = typename Traits::WarpGemm;
|
||||
using BlockGemmShape = typename Traits::BlockGemmShape;
|
||||
static constexpr auto BlockGemmLoopOrder = Traits::BlockGemmLoopOrder;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Traits::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Traits::BDataType>;
|
||||
using CDataType = remove_cvref_t<typename Traits::CDataType>;
|
||||
|
||||
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
|
||||
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
|
||||
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
|
||||
|
||||
static constexpr index_t MWarp = Traits::MWarp;
|
||||
static constexpr index_t NWarp = Traits::NWarp;
|
||||
static constexpr bool UseDefaultScheduler = (Problem::NumWaveGroups != 1);
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto a_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<NWarp>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
tuple<sequence<1, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
|
||||
|
||||
return a_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeBBlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding =
|
||||
tile_distribution_encoding<sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<2, 1>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
constexpr auto b_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
tuple<sequence<0, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
|
||||
return b_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockDistributionEncode()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
|
||||
return c_block_dstr_encode;
|
||||
}
|
||||
}
|
||||
|
||||
// C += A * B
|
||||
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
|
||||
const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
|
||||
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
|
||||
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
// check ABC-block-distribution
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeABlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"A distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeBBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"B distribution is wrong!");
|
||||
static_assert(
|
||||
std::is_same_v<remove_cvref_t<decltype(MakeCBlockDistributionEncode())>,
|
||||
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding())>>,
|
||||
"C distribution is wrong!");
|
||||
|
||||
using AWarpDstr = typename WarpGemm::AWarpDstr;
|
||||
using BWarpDstr = typename WarpGemm::BWarpDstr;
|
||||
using CWarpDstr = typename WarpGemm::CWarpDstr;
|
||||
|
||||
using AWarpTensor = typename WarpGemm::AWarpTensor;
|
||||
using BWarpTensor = typename WarpGemm::BWarpTensor;
|
||||
using CWarpTensor = typename WarpGemm::CWarpTensor;
|
||||
|
||||
constexpr auto a_warp_y_lengths =
|
||||
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto b_warp_y_lengths =
|
||||
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
constexpr auto c_warp_y_lengths =
|
||||
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
|
||||
|
||||
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
|
||||
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
|
||||
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
|
||||
|
||||
// hot loop:
|
||||
if constexpr(BlockGemmLoopOrder == GemmLoopOrder::KMN)
|
||||
{
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, mIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<kIter, nIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
CWarpTensor c_warp_tensor;
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
else if constexpr(BlockGemmLoopOrder == GemmLoopOrder::MNK)
|
||||
{
|
||||
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
|
||||
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
|
||||
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
|
||||
// read A warp tensor from A Block window
|
||||
AWarpTensor a_warp_tensor;
|
||||
|
||||
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
|
||||
|
||||
// read B warp tensor from B block tensor
|
||||
BWarpTensor b_warp_tensor;
|
||||
|
||||
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
|
||||
|
||||
// read C warp tensor from C block tensor
|
||||
CWarpTensor c_warp_tensor;
|
||||
|
||||
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
|
||||
|
||||
// warp GEMM
|
||||
WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
|
||||
|
||||
// write C warp tensor into C block tensor
|
||||
c_block_tensor.set_y_sliced_thread_data(
|
||||
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
|
||||
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
|
||||
c_warp_tensor.get_thread_buffer());
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
|
||||
{
|
||||
if constexpr(UseDefaultScheduler)
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<MWarp>,
|
||||
tuple<sequence<MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<>,
|
||||
tuple<>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
|
||||
sequence<>,
|
||||
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
|
||||
tuple<sequence<1, 2>>,
|
||||
tuple<sequence<1, 1>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{};
|
||||
|
||||
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
|
||||
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
|
||||
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
|
||||
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
|
||||
return c_block_tensor;
|
||||
}
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
template <typename ABlockTensor, typename BBlockTensor>
|
||||
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
|
||||
const BBlockTensor& b_block_tensor) const
|
||||
{
|
||||
auto c_block_tensor = MakeCBlockTile();
|
||||
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
|
||||
return c_block_tensor;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,45 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct GemmLoopOrder
|
||||
{
|
||||
KMN,
|
||||
MNK,
|
||||
};
|
||||
|
||||
template <typename AType_,
|
||||
typename BType_,
|
||||
typename CType_,
|
||||
typename BlockWarps_,
|
||||
typename WarpGemm_,
|
||||
GemmLoopOrder BlockGemmLoopOrder_ = GemmLoopOrder::KMN>
|
||||
struct BlockGemmARegBRegCRegV2CustomPolicy
|
||||
{
|
||||
using AType = remove_cvref_t<AType_>;
|
||||
using BType = remove_cvref_t<BType_>;
|
||||
using CType = remove_cvref_t<CType_>;
|
||||
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
|
||||
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
|
||||
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
|
||||
|
||||
using WarpGemm = remove_cvref_t<WarpGemm_>;
|
||||
|
||||
static constexpr auto BlockGemmLoopOrder = BlockGemmLoopOrder_;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -104,6 +104,10 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32SwizzleBTransposedCDistribution =
|
||||
1>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
@@ -210,6 +214,10 @@ using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
|
||||
AttrNumAccess>>;
|
||||
#endif
|
||||
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;
|
||||
|
||||
#if defined(__gfx950__)
|
||||
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
|
||||
WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB<
|
||||
|
||||
@@ -45,6 +45,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp16 2:4 structural sparsity
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
@@ -74,6 +76,8 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
|
||||
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleBTransposedCDistribution; };
|
||||
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, true, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution; };
|
||||
|
||||
// fp8
|
||||
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
|
||||
|
||||
@@ -14,10 +14,14 @@ namespace ck_tile {
|
||||
* Y dim must have at least one dim not been reduced
|
||||
*/
|
||||
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
|
||||
template <typename AccDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
bool WithBroadcast = true,
|
||||
bool CrossWarp = true>
|
||||
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
bool_constant<WithBroadcast> = {})
|
||||
bool_constant<WithBroadcast> = {},
|
||||
bool_constant<CrossWarp> = {})
|
||||
{
|
||||
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
@@ -56,14 +60,24 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
constexpr index_t lid_delta =
|
||||
lid_over_rid_derivative * (1 << (nstage - istage - 1));
|
||||
if constexpr(CrossWarp)
|
||||
{
|
||||
constexpr index_t lid_delta =
|
||||
lid_over_rid_derivative * (1 << (nstage - istage - 1));
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle_down(v_local, lid_delta);
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
// pull data from remote lane
|
||||
const auto v_swapped_regs = warp_shuffle_down_pair(v_local);
|
||||
// reduce
|
||||
v_local = reduce_func(v_swapped_regs.at(0), v_swapped_regs.at(1));
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user