[CK_TILE] Support f32 in FMHA (fwd and bwd) (#2836)

* Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout

Add comments with dropout implementation details

Fix performance regression of fwd+dropout

    * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox;
    * "scalarize" seed and offset, they may come either from kernel args or from device memory
      (presumably loaded with vector loads).

    These changes help the compiler to produce more optimal code and reduce register spilling.

Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding

Use code based on BlockDropout in BlockDropoutBwd

Refactor BlockDropout (fwd)

Implement BlockDropout (fwd) for WMMA

    Originally BlockDropout only supported 32x32 tiles (IsWG32 = true),
    this version supports 16x16 tiles.
    If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly
    to BlockDropoutBwd.

Implement BlockDropoutBwd for WMMA

Remove MakeRandValLds* functions unused in BlockDropoutBwd

Remove unused Run overload from BlockDropoutBwd

* Fix regression with philox seed and offset when they exceed 32-bit int

__builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset
are 64-bit so they get truncated.

* Add F32 MFMA warp gemms

* Support f32 in fwd FMHA

* Implement transpose_vectors for 4-byte types (float)

* Fix unexpected implicit f32->uint32 cast in buffer_store<4>

__builtin_amdgcn_raw_buffer_store_b32 expects unsigned int but float was passed (implicitly cast to uint).
mbuf_t types in other buffer_store<> are changed for consistency.

* Support F32 in bwd FMHA

hdim = 256 is disabled for now because it uses too much memory on gfx90a

* Support Headdim = 48 (divisible by 16) in fwd

* Add fp32-specific receipts (800 and 801)

* Tune fwd tiles

* Tune bwd tiles

* Use small tiles only for small seqlen_q

* Fix after rebasing

* Fix selection of a fallback tile based on bm0

The assumption that the largest bm0 == 128 is not always true for
current fp32 tiles.

* Remove constraints and adjust filtering for fp32

Custom constraints are no longer needed because now the smallest tile
is selected automatically based on seqlen_q.
Filters related to qr_async_trload disabled valid fp32 tiles.

* Add fp32 tests

* Make splitkv and appendkv compile for fp32 only

There are no instances yet, but API still must compile when only fp32 is
requested.

* Remove unimportant f32 instances

* Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS

* Replace magic numbers with a constant, improve comments for dropout

* Update changelog

* Fix condition that dq_acc must be set to zero when mask is used

The change was introduced in #2799

* Replace warp_uniform with recently added amd_wave_read_first_lane

* Add hdim = 96 and 192 to fwd

[ROCm/composable_kernel commit: 1edd250115]
This commit is contained in:
Anton Gorenko
2025-09-27 19:03:48 +06:00
committed by GitHub
parent 4158d33735
commit bc9362af55
31 changed files with 922 additions and 488 deletions

View File

@@ -32,7 +32,8 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added benchmarking support for tile engine GEMM Multi D.
* Added block scaling support in CK_TILE GEMM, allowing flexible use of quantization matrices from either A or B operands.
* Added the row-wise column-wise quantization for CK_TILE GEMM & CK_TILE Grouped GEMM.
* Added tensor-wise quantization for CK_TILE GEMM
* Added support for f32 to FMHA (fwd/bwd).
* Added tensor-wise quantization for CK_TILE GEMM.
### Optimized

View File

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp32" : "FmhaFwdFp32",
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
@@ -12,6 +13,7 @@ FWD_DTYPE_MAP = {
}
BWD_DTYPE_MAP = {
"fp32": "FmhaBwdFp32",
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}

View File

@@ -601,6 +601,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -370,7 +370,14 @@ class FmhaBwdDQDKDVKernel:
# TODO: design a more practical way to do it
# this is current supported tile size.
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
if dtype == 'fp32' and tr_load == 'f':
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
return [
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
@@ -865,6 +872,30 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
# fp32 only, all variations
if receipt == 800:
cond = dtype == 'fp32'
cond &= dpad == dvpad
if not cond:
continue
# fp32 only, minimal set of parameters
elif receipt == 801:
cond = dtype == 'fp32'
cond &= hdim in [64, 128]
cond &= dpad == dvpad
cond &= mode == 'batch'
cond &= bias == 'no'
cond &= dropout == 'no'
cond &= mask == 's_no'
cond &= deterministic == "f"
if not cond:
continue
else:
# Don't build fp32 by default
if dtype == 'fp32':
continue
gen_dot_do_o[t.dot_do_o_kernel] = True
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
if not t.convert_dq_kernel.disabled:

View File

@@ -25,6 +25,7 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {
32 : 32,
48 : 48,
64 : 64,
96 : 128,
128: 128,
@@ -164,7 +165,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
@@ -249,9 +250,8 @@ class FmhaFwdApiTrait:
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def seqtune(self) -> str:
if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true
def seqtune(self, max_bm0 : int) -> str:
if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/'
else:
return f'a.seqlen_q <= {self.bm0}'
@@ -386,6 +386,7 @@ class FmhaFwdApiPool:
per_hdim_case=str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
max_bm0 = max((t.bm0 for t in traits), default=0)
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
@@ -393,7 +394,7 @@ class FmhaFwdApiPool:
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
@@ -534,7 +535,20 @@ class KernelComponentFactory:
# this is current supported tile size per hdim
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
if dtype == 'fp32':
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
}
elif dtype == 'fp16' or dtype == 'bf16':
return {
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
@@ -572,7 +586,13 @@ class KernelComponentFactory:
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in ['fp16', 'bf16']:
if dtype in ['fp32']:
squant = 'f'
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
elif dtype in ['fp16', 'bf16']:
squant = 'f'
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
if hdim == 256 and hdim_v == 256:
@@ -626,6 +646,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, next_tile in zip(tiles, tiles[1:]):
assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0'
for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
@@ -635,12 +657,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
continue
if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
continue
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
continue
if dtype != 'fp32':
if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
continue
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
@@ -710,6 +733,31 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
if not cond:
continue
# fp32 only, all variations
if receipt == 800:
cond = dtype == 'fp32'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
if not cond:
continue
# fp32 only, minimal set of parameters
elif receipt == 801:
cond = dtype == 'fp32'
cond &= hdim in [48, 128]
cond &= mode == 'batch'
cond &= pipeline.F_bias == 'no'
cond &= pipeline.F_lse == 'f'
cond &= pipeline.F_dropout == 'f'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
cond &= pipeline.F_mask == 's_no'
if not cond:
continue
else:
# Don't build fp32 by default
if dtype == 'fp32':
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -184,6 +184,9 @@ class FmhaFwdAppendKVApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
@dataclass
@@ -341,6 +344,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -768,6 +768,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
@@ -834,6 +841,13 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
gen.append(k)
return gen

View File

@@ -560,6 +560,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -43,7 +43,7 @@ auto create_args(int argc, char* argv[])
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("dbias", "0", "output bias gradient or not")
.insert("prec", "fp16", "data type. fp16 or bf16")
.insert("prec", "fp16", "data type. fp32/fp16/bf16")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -159,7 +159,11 @@ int main(int argc, char* argv[])
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
if(data_type == "fp32")
{
return run<FmhaBwdFp32>(arg_parser) == bwd_result::success ? 0 : -2;
}
else if(data_type == "fp16")
{
return run<FmhaBwdFp16>(arg_parser) == bwd_result::success ? 0 : -2;
}

View File

@@ -67,7 +67,7 @@ auto create_args(int argc, char* argv[])
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -227,7 +227,11 @@ int main(int argc, char* argv[])
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
if(data_type == "fp32")
{
return run<FmhaFwdFp32>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp16")
{
return run<FmhaFwdFp16>(arg_parser) == fwd_result::success ? 0 : -2;
}

View File

@@ -35,6 +35,14 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaBwdFp32>(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
double rtol = 1e-4;
double atol = 1e-4;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
@@ -77,7 +85,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::optional<std::string> json = std::nullopt)
{
const std::string data_type = []() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdFp16>)
if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdFp32>)
return "fp32";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdFp16>)
return "fp16";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdBf16>)
return "bf16";
@@ -776,7 +786,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
// non-deterministic kernels use atomic add to write dq
// Some block may be skipped with causal mask and dq are not set to zeros
// In these cases thus we need to zero out it first
if(!deterministic || mask.type == mask_enum::no_mask)
if(!deterministic || mask.type != mask_enum::no_mask)
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -17,6 +17,10 @@
#include <utility>
#include <variant>
struct FmhaFwdFp32
{
};
struct FmhaFwdFp16
{
};
@@ -48,6 +52,22 @@ struct FmhaFwdFp8Fp32
template <typename DataType>
struct FmhaFwdTypeConfig;
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp32>
{
using QDataType = float;
using KDataType = float;
using VDataType = float;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = float; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp16>
{

View File

@@ -41,6 +41,14 @@ auto get_elimit(std::string /*init_method*/)
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaFwdFp32>(std::string /*init_method*/)
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{
@@ -180,7 +188,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::optional<std::string> json = std::nullopt)
{
const std::string data_type = []() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp16>)
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
return "fp32";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp16>)
return "fp16";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf16>)
return "bf16";

View File

@@ -470,7 +470,7 @@ struct buffer_store<16>
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 16);
using mbuf_t = fp32x4_t;
using mbuf_t = uint32x4_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b128(
@@ -496,7 +496,7 @@ struct buffer_store<8>
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 8);
using mbuf_t = fp32x2_t;
using mbuf_t = uint32x2_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b64(
@@ -522,7 +522,7 @@ struct buffer_store<4>
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
using mbuf_t = uint32_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b32(
@@ -548,7 +548,7 @@ struct buffer_store<2>
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 2);
using mbuf_t = short;
using mbuf_t = uint16_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b16(
@@ -573,8 +573,8 @@ struct buffer_store<1>
index_t i_offset /*max 0xFFF*/,
index_t /*flag*/ = 1)
{
static_assert(sizeof(T) == 4);
using mbuf_t = float;
static_assert(sizeof(T) == 1);
using mbuf_t = uint8_t;
#if HAS_RAW_BUFFER_BUILTINS
index_t s_offset = i_offset;
__builtin_amdgcn_raw_buffer_store_b8(

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -55,7 +55,8 @@ class philox
CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
const index_t idx0,
const index_t idx1) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
@@ -66,13 +67,12 @@ class philox
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
out_tmp[1] = tmp[start_idx + 2];
out_tmp[0] = tmp[idx0];
out_tmp[1] = tmp[idx1];
}
CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
const unsigned long long subsequence,
const index_t start_idx) const
CK_TILE_HOST_DEVICE void
get_random_4x8(uint8_t* out, const unsigned long long subsequence, const index_t idx) const
{
uint4 tmp_ph;
tmp_ph = get_philox_4x32(subsequence);
@@ -83,7 +83,7 @@ class philox
tmp[2] = tmp_ph.z;
tmp[3] = tmp_ph.w;
uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
out_tmp[0] = tmp[start_idx];
out_tmp[0] = tmp[idx];
}
private:

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -34,7 +34,13 @@ struct transpose_vectors
constexpr auto I3 = number<3>{};
constexpr auto I4 = number<4>{};
if constexpr(sizeof(S) == 2)
if constexpr(sizeof(S) == 4)
{
static_for<0, NY, 1>{}([&](auto iy) {
static_for<0, NX, 1>{}([&](auto ix) { vy_tuple(iy)(ix) = vx_tuple[ix][iy]; });
});
}
else if constexpr(sizeof(S) == 2)
{
static_assert((NX % 2 == 0 && NY % 2 == 0), "wrong!");

View File

@@ -33,18 +33,22 @@ reference_batched_dropout_randval(HostTensor<RandValOutputDataType>& randval_b_m
// With SFactor = 2 it becomes:
// C i: (16 * floor(GPR_num / 8) % 32) + 8 * floor(lane / 32) + (GPR_num % 8)
// C j: (lane % 32)
// See ck_tile/ops/fmha/block/block_dropout.hpp for more details.
constexpr index_t max_warp_size = 64;
constexpr index_t warp_gemm_mn = 32;
// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
constexpr index_t philox_per_tile = 64;
constexpr index_t warp_gemm_mn = 32;
const index_t rows = integer_divide_ceil(real_seqlen_q, warp_gemm_mn);
const index_t cols = integer_divide_ceil(real_seqlen_k, warp_gemm_mn);
auto f = [&](index_t i_h, index_t row, index_t col) {
uint2 rowcol = make_uint2(row, col);
for(index_t lane = 0; lane < max_warp_size; lane++)
for(index_t lane = 0; lane < philox_per_tile; lane++)
{
philox ph(drop_seed, drop_offset + (batch * nhead + i_h) * max_warp_size + lane);
const uint64_t ph_head_offset = drop_offset + (batch * nhead + i_h) * philox_per_tile;
const index_t ph_offset = lane;
philox ph(drop_seed, ph_head_offset + ph_offset);
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

View File

@@ -1,17 +1,44 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
// BlockDropoutBwd and BlockDropout (fwd) support two warp gemm tile sizes: 32x32 (MFMA only) and
// 16x16 (MFMA and WMMA). Even if fwd and bwd use different tile sizes, generated random
// numbers will be the same, they are also the same for MFMA (on CDNA), WMMA (on RDNA), or host
// (for verification, see ck_tile/host/reference/reference_batched_dropout_randval.hpp).
//
// The (row, col) coordinate of the current 32x32 tile in the P matrix determines a subsequence of
// random numbers (ph_subsequence).
// The (batch, head, 0..63) coordinate determines an offset in the subsequence (ph_head_offset and
// ph_offset).
// This means that subsequences are non-overlapping, reproducible and independent of mask or window.
//
// There are 3 modes (all produce the same results):
// * For 32x32 MFMA tile each of 64 lanes generates 4 * 32 bits or 16 bytes, so one warp generates
// the entire 32x32 tile (64 * 16 = 32 * 32).
// * For 16x16 MFMA tile one warp generates 1/4 of the 32x32 tile ((16 * 16) / (64 * 16) = 1/4), 4
// warps generate the same 64 * 16 random bytes and each uses its own quarter. If kMPerBlock >
// MWarp * WG::kM one warp can generate two 16x16 tiles (MIterPerWarp = 2) so fewer instructions
// are needed for generating a 32x32 tile.
// * For 16x16 WMMA tile one warp generates 1/2 of the 32x32 tile ((16 * 16) / (32 * 16) = 1/2), 2
// warps generate the same 64 * 16 random bytes and each uses its own half. If kMPerBlock > MWarp *
// WG::kM one warp can generate two 16x16 tiles.
namespace detail {
// The number of Philox 4x32 results required to fill 32x32 tile of 8-bit values
constexpr index_t philox_per_tile = 64;
} // namespace detail
struct NullBlockDropout
{
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
@@ -32,7 +59,9 @@ struct BlockDropout
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_,
bool is_store_randval_)
: ph(seed, offset + (i_batch * nheads + i_head) * get_warp_size() + get_lane_id()),
: ph_seed(amd_wave_read_first_lane(seed)),
ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
detail::philox_per_tile)),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_),
is_store_randval(is_store_randval_)
@@ -46,11 +75,15 @@ struct BlockDropout
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
@@ -78,12 +111,17 @@ struct BlockDropout
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = WG::kN;
constexpr index_t kN1 = 8;
constexpr index_t kN0 = kNPerStep / kN1;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
constexpr index_t kN1 = 8;
constexpr index_t kN0 = kNPerStep / kN1;
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
@@ -107,33 +145,35 @@ struct BlockDropout
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t NIterPerWarp = 1;
// The tile distribution is different from the one in MakeRandValLdsShuffleTileDistribution,
// because it can combine 2 (MIterPerWarp) 16x16 subtiles for generating them at once
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<MWarp, MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
sequence<1, 0>>{};
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
constexpr auto randval_block_inner_part_dstr_encoding = []() {
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
std::is_same_v<typename BlockGemm::CDataType, float>)
{
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
}
else
{
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
}
}();
constexpr auto randval_block_inner_part_dstr_encoding =
typename WarpGemmDispatcher<typename WG::ADataType,
typename WG::BDataType,
typename WG::CDataType,
WG::kM,
WG::kN,
WG::kK,
false,
IsWG32>::CWarpDstrEncoding{};
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
@@ -147,11 +187,13 @@ struct BlockDropout
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
@@ -181,14 +223,16 @@ struct BlockDropout
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// randval tile in LDS
auto randval_lds = make_tensor_view<address_space_enum::lds>(
@@ -200,42 +244,100 @@ struct BlockDropout
// register distribute
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
auto randval_lds_read_window =
const auto randval_lds_read_window =
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
randval_lds_window.get_window_lengths(),
randval_lds_window.get_window_origin(),
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
const int start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
const index_t start_m0_idx = randval_dram_window.get_window_origin().at(number<0>{});
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
auto generate_randval = [&](auto i_m0, auto i_n0) {
// Generate random numbers
uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
if constexpr(IsWG32)
{
// Generate the whole 32x32 tile at once (each tile consists of random numbers taken
// from a separate subsequence of Philox)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
const index_t ph_offset = get_lane_id();
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
else
{
// Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
// MIterPerWarp is equal to 1 or 2)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
const index_t subtile_m0 = wg_m0 % 2;
if constexpr(get_warp_size() == 32)
{
const index_t ph_offset = (get_lane_id() & 15) +
(((get_lane_id() >> 4) & 1) << 5) +
((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
}
else
{
const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
ph.get_random_4x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
}
}
}
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// Transpose randval using LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
const auto randval = load_tile(randval_lds_read_window);
block_sync_lds();
return randval;
};
if(is_store_randval)
{
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
const auto randval = generate_randval(i_m0, i_n0);
// save to Global
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
@@ -244,37 +346,21 @@ struct BlockDropout
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
});
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
};
}
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
const auto randval = generate_randval(i_m0, i_n0);
// Drop values of P based on the generated probabilities
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
constexpr auto p_idx0 =
tile_distributed_index<i_m0 * MIterPerWarp +
idx0.impl_.template at<0>()>{};
constexpr auto p_idx1 =
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
tile_distributed_index<i_n0,
idx1.impl_.template at<1>(),
idx1.impl_.template at<2>()>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
@@ -286,12 +372,15 @@ struct BlockDropout
});
}
ck_tile::philox ph;
const unsigned long long ph_seed;
const unsigned long long ph_head_offset;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
const bool is_store_randval;
};
// TODO: IsWG32_ is not needed as template parameter and can be removed. IsDropout_ == false can be
// replaced with NullBlockDropout. This requires changes in xformers and other libs.
template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd;
@@ -301,8 +390,8 @@ struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
static constexpr bool IsDropout = false;
static constexpr bool IsStoreRandval = IsStoreRandval_;
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
__host__ __device__ static constexpr auto
template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
@@ -316,10 +405,7 @@ struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
template <bool IsWG32_, bool IsStoreRandval_>
struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
{
static constexpr bool IsDropout = true;
// true: 32*32 warp gemm
// false: 16*16 warp gemm
static constexpr bool IsWG32 = IsWG32_;
static constexpr bool IsDropout = true;
static constexpr bool IsStoreRandval = IsStoreRandval_;
CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
@@ -329,38 +415,30 @@ struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
unsigned long long offset,
float rp_undrop_,
uint8_t p_undrop_in_uint8_t_)
: ph(seed,
offset + (i_batch * nheads + i_head) * get_warp_size() +
(IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
: ph_seed(amd_wave_read_first_lane(seed)),
ph_head_offset(amd_wave_read_first_lane(offset + (i_batch * nheads + i_head) *
detail::philox_per_tile)),
rp_undrop(rp_undrop_),
p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
{
}
template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
template <typename BlockGemm, bool IsFwd = false, typename RandValDramBlockWindowTmp>
CK_TILE_HOST_DEVICE static constexpr auto
MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
index_t seqlen_qk_start)
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return MWarp * WG::kM * 2;
}
else
{
return MWarp * WG::kM;
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
const auto block_origin = randval_dram_block_window_tmp.get_window_origin();
auto randval_dram_window = [&]() {
@@ -384,85 +462,39 @@ struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = WG::kN;
constexpr index_t kN1 = 8;
constexpr index_t kN0 = kNPerStep / kN1;
constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
number<kN1>{},
number<1>{});
constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
randval_lds_block_desc_0,
ck_tile::make_tuple(
make_pass_through_transform(number<kMPerStep>{}),
make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
return randval_lds_block_desc;
}
template <typename BlockGemm, bool IsFwd = true>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
constexpr index_t MIterPerWarp = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
{
return 2;
}
else
{
return 1;
}
}();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<MWarp, MIterPerWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
sequence<1, 0>>{};
// Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd.
// except headdim256.
constexpr auto randval_block_inner_part_dstr_encoding = []() {
if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
std::is_same_v<typename BlockGemm::BDataType, half_t> &&
std::is_same_v<typename BlockGemm::CDataType, float>)
{
if constexpr(IsWG32)
return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
}
else
{
if constexpr(IsWG32)
return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
else
return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
}
}();
constexpr auto randval_block_inner_part_dstr_encoding =
typename WarpGemmDispatcher<typename WG::ADataType,
typename WG::BDataType,
typename WG::CDataType,
WG::kM,
WG::kN,
WG::kK,
false,
IsWG32>::CWarpDstrEncoding{};
static_assert(
std::is_same_v<remove_cvref_t<decltype(randval_block_inner_part_dstr_encoding)>,
typename WG::CWarpDstrEncoding>);
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
@@ -471,129 +503,6 @@ struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = 1;
constexpr index_t NIterPerWarp = 1;
constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto randval_block_part_dstr_encode =
detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
typename WG::CWarpDstrEncoding{});
return make_static_tile_distribution(randval_block_part_dstr_encode);
}
template <typename BlockGemm,
typename PComputeDataType,
typename RandValOutputDataType,
typename PComputeWindow,
typename RandValDramWindow>
CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
const index_t start_m0_idx,
const index_t start_n0_idx,
PComputeWindow& p_compute,
RandValDramWindow& randval_dram_window) const
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t kMPerStep = MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// randval tile in LDS
auto randval_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
auto randval_lds_window = make_tile_window(
randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
// register distribute
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
auto randval_lds_read_window =
make_tile_window(randval_lds_window.get_bottom_tensor_view(),
randval_lds_window.get_window_lengths(),
randval_lds_window.get_window_origin(),
MakeRandValLdsShuffleTileDistribution<BlockGemm>());
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
int block_col_start = (start_n0_idx / WG::kN) + i_n0;
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
// save to LDS
store_tile(randval_lds_window, randval_dist_generated);
block_sync_lds();
// read from LDS to register
auto randval = load_tile(randval_lds_read_window);
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
constexpr auto p_idx1 =
tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
? p_compute[p_idx] * rp_undrop
: PComputeDataType(0);
});
});
// save to Global
if constexpr(IsStoreRandval)
{
const auto randval_store = cast_tile<RandValOutputDataType>(randval);
store_tile(randval_dram_window, randval_store);
move_tile_window(randval_dram_window, {0, kNPerStep});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
}
});
if constexpr(IsStoreRandval)
{
move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
}
}
template <typename BlockGemm,
typename RandValOutputDataType,
typename PComputeWindow,
@@ -605,92 +514,111 @@ struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
{
constexpr auto config =
BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr bool MBwdWG16MultiIterCheck = (!IsWG32) && (kMPerBlock > 16);
constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
constexpr index_t kMPerStep = [&]() {
if constexpr(MBwdWG16MultiIterCheck)
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr bool IsWG32 = WG::kM == 32;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
constexpr index_t kMPerBlock = BlockGemmShape::kM;
constexpr index_t kNPerBlock = BlockGemmShape::kN;
constexpr index_t MIterPerWarp = (!IsWG32 && kMPerBlock > MWarp * WG::kM) ? 2 : 1;
constexpr index_t kMPerStep = MIterPerWarp * MWarp * WG::kM;
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval_dist_generated =
make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
const index_t iMWarp = get_warp_id() / NWarp;
const index_t iNWarp = get_warp_id() % NWarp;
auto generate_randval = [&](auto i_m0, auto i_n0) {
// Generate random numbers
uint8_t random_uint8_t[randval_dist_generated.kThreadElementSpaceSize];
const index_t wg_m0 = (start_m0_idx / WG::kM) + (i_m0 * MWarp + iMWarp) * MIterPerWarp;
const index_t wg_n0 = (start_n0_idx / WG::kN) + (i_n0 * NWarp + iNWarp);
if constexpr(IsWG32)
{
return MWarp * WG::kM * 2;
// Generate the whole 32x32 tile at once (each tile consists of random numbers
// taken from a separate subsequence of Philox)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0, wg_n0));
const index_t ph_offset = get_lane_id();
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
else
{
return MWarp * WG::kM;
// Generate one or two 16x16 subtiles of the 32x32 tile (depending on whether
// MIterPerWarp is equal to 1 or 2)
const unsigned long long ph_subsequence =
bit_cast<unsigned long long>(make_uint2(wg_m0 / 2, wg_n0 / 2));
const index_t subtile_m0 = wg_m0 % 2;
if constexpr(get_warp_size() == 32)
{
const index_t ph_offset = (get_lane_id() & 15) +
(((get_lane_id() >> 4) & 1) << 5) +
((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + 0, subtile_m0 * 2 + 1);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);
ph.get_random_16x8(random_uint8_t, ph_subsequence);
}
}
else
{
const index_t subtile_n0 = (get_lane_id() >> 4) & 1;
const index_t ph_offset = (get_lane_id() & 47) + ((wg_n0 % 2) << 4);
const ck_tile::philox ph(ph_seed, ph_head_offset + ph_offset);
if constexpr(MIterPerWarp == 1)
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 4);
ph.get_random_4x8(
random_uint8_t, ph_subsequence, subtile_m0 * 2 + subtile_n0);
}
else
{
static_assert(randval_dist_generated.kThreadElementSpaceSize == 8);
ph.get_random_8x8(
random_uint8_t, ph_subsequence, 0 * 2 + subtile_n0, 1 * 2 + subtile_n0);
}
}
}
}();
constexpr index_t kNPerStep = NWarp * WG::kN;
// register distribute
auto randval = make_static_distributed_tensor<uint8_t>(
MakeRandValTileDistribution<BlockGemm, false>());
if constexpr(IsWG32)
static_assert(randval.kThreadElementSpaceSize == 16);
else
static_assert(randval.kThreadElementSpaceSize == 4 ||
randval.kThreadElementSpaceSize == 8);
constexpr auto randval_dist_generated_spans =
decltype(randval_dist_generated)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
});
});
return randval_dist_generated;
};
static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
int block_row_start, block_col_start;
if constexpr(IsWG32)
{
block_row_start = (start_m0_idx / WG::kM) + i_m0;
block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
}
else
{
block_row_start = start_m0_idx / 32 + i_m0;
block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
}
uint2 rowcol = make_uint2(block_row_start, block_col_start);
// generate random number
uint8_t* random_uint8_t_;
if constexpr(MBwdWG16SingleIterCheck)
{
uint8_t random_uint8_t[4];
// m0t0 ~m0t15/m0t32~m0t47: 0
// m0t16~m0t31/m0t48~m0t63: 1
// m1t0 ~m1t15/m1t32~m1t47: 2
// m1t16~m1t31/m1t48~m1t63: 3
const index_t start_idx =
((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
ph.get_random_4x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else if constexpr(MBwdWG16MultiIterCheck)
{
uint8_t random_uint8_t[8];
// t0 ~t15/t32~t47: 0
// t16~t31/t48~t63: 1
const index_t start_idx = (get_lane_id() >> 4) & 1;
ph.get_random_8x8(
random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
random_uint8_t_ = random_uint8_t;
}
else
{
uint8_t random_uint8_t[16];
ph.get_random_16x8(random_uint8_t,
reinterpret_cast<unsigned long long&>(rowcol));
random_uint8_t_ = random_uint8_t;
}
const auto randval = generate_randval(i_m0, i_n0);
// Drop values of P based on the generated probabilities, negative sign is used to
// distinguish such values later in bwd pipeline.
constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
int i_random_idx = 0;
sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
randval(r_idx) = random_uint8_t_[i_random_idx++];
constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
idx0.impl_.at(1),
idx0.impl_.at(2)>{};
constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
constexpr auto p_idx0 =
tile_distributed_index<i_m0 * MIterPerWarp +
idx0.impl_.template at<0>(),
idx0.impl_.template at<1>(),
idx0.impl_.template at<2>()>{};
constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
@@ -717,7 +645,8 @@ struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
}
}
ck_tile::philox ph;
const unsigned long long ph_seed;
const unsigned long long ph_head_offset;
const float rp_undrop;
const uint8_t p_undrop_in_uint8_t;
};

View File

@@ -82,6 +82,7 @@ struct FmhaBwdDQDKDVKernel
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
// clang-format on
@@ -1187,6 +1188,7 @@ struct FmhaBwdOGradDotOKernel
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
// clang-format on
@@ -1443,6 +1445,7 @@ struct FmhaBwdConvertQGradKernel
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
// clang-format on

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -32,12 +32,27 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
std::is_same_v<typename Problem::KDataType, float> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 16);
return WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
@@ -49,6 +64,8 @@ struct BlockFmhaFwdPagedKVPipelineQRKSVSDefaultPolicy
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -264,12 +264,27 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
std::is_same_v<typename Problem::KDataType, float> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 16);
return WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
@@ -281,6 +296,8 @@ struct BlockFmhaPipelineQRKSVSWholeKPrefetchDefaultPolicy
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)

View File

@@ -73,12 +73,27 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
constexpr auto warp_gemm = []() {
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
std::is_same_v<typename Problem::KDataType, float> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 16);
return WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
@@ -90,6 +105,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
@@ -201,7 +218,7 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
constexpr auto q_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kMPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kMPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<8>{},
number<kKPack>{},
number<1>{});
constexpr auto q_lds_block_desc = transform_tensor_descriptor(
@@ -228,14 +245,29 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
constexpr index_t WarpGemmM = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
if constexpr(std::is_same_v<typename Problem::QDataType, float> &&
std::is_same_v<typename Problem::KDataType, float> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 16);
return WarpGemmDispatcher<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::SaccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
true>{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
@@ -247,6 +279,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 4 || WarpGemmM == 16 || WarpGemmM == 32);
if constexpr(WarpGemmM == 32)
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution{};
else if constexpr(WarpGemmM == 16)
@@ -258,6 +292,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
std::is_same_v<typename Problem::KDataType, fp8_t> &&
std::is_same_v<typename Problem::SaccDataType, float>)
{
static_assert(WarpGemmM == 32);
// TODO: hard coded here. Otherwise, it may incorrect result
constexpr index_t swizzle_factor = 4;
return WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution<
@@ -507,7 +543,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr auto k_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / kKPack>{}, number<kNPerBlock>{}, number<kKPack>{}),
make_tuple(number<(kNPerBlock + 1) * kKPack>{}, number<kKPack>{}, number<1>{}),
number<8>{},
number<kKPack>{},
number<1>{});
constexpr auto k_lds_block_desc = transform_tensor_descriptor(
@@ -806,15 +842,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
constexpr index_t kNPack = 32;
static_assert(kNPerBlock % kNPack == 0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
static_assert(kNPerBlock % 16 == 0);
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
@@ -824,7 +859,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
sequence<1, 2, 1>, // N0 K2 N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
else if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
@@ -863,13 +898,40 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0);
return make_static_tile_distribution(
constexpr auto dstr = make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K0
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 2>, // N0 K1
sequence<0, 1>>{});
if constexpr(container_reduce(dstr.get_lengths(), std::multiplies<index_t>{}, 1) ==
kNPerBlock * kKPerBlock)
{
return dstr;
}
else
{
static_assert(kKPerBlock % 16 == 0);
constexpr index_t kKPerIter = kKPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0_m = kKPerBlock / kKPerIter;
constexpr index_t K2 = 2;
constexpr index_t K1_m = kKPerIter / K2;
constexpr index_t N2_m = get_warp_size() / K1_m;
constexpr index_t N0_m = kNPerBlock / (N2_m * N1);
constexpr auto dstr_m = make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<N0_m, N1, N2_m>, sequence<K0_m, K1_m, K2>>,
tuple<sequence<1>, sequence<1, 2>>, // N1, N2 K1
tuple<sequence<1>, sequence<2, 1>>,
sequence<2, 1, 2>, // K0 N0 K2
sequence<0, 0, 2>>{});
static_assert(container_reduce(dstr_m.get_lengths(),
std::multiplies<index_t>{},
1) == kNPerBlock * kKPerBlock);
return dstr_m;
}
}
}
@@ -897,14 +959,14 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
if constexpr(total_pixels % N1 != 0 || kKPack % K3 != 0) // if K2 or K3 is not divisible
{
constexpr index_t kNPack = 32;
static_assert(kNPerBlock % kNPack == 0);
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
static_assert(kNPerBlock % 16 == 0);
constexpr index_t kNPack = kNPerBlock % 32 == 0 ? 32 : 16;
constexpr index_t K0 = kBlockSize / get_warp_size();
constexpr index_t N2 = 2;
constexpr index_t N1_m = kNPack / N2;
constexpr index_t N0_m = kNPerBlock / kNPack;
constexpr index_t K1 = get_warp_size() / N1_m;
constexpr index_t K2_m = kKPerBlock / K1 / K0;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0_m, N1_m, N2>, sequence<K0, K1, K2_m>>,
@@ -913,7 +975,7 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
sequence<1, 1, 2>, // N0 K2 <-> N2
sequence<0, 2, 2>>{});
}
else if constexpr(get_warp_size() % (kKPack / K3 * N0) == 0)
else if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();

View File

@@ -7,20 +7,22 @@
namespace ck_tile {
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length(index_t len)
template <index_t Headdim>
static CK_TILE_HOST_DEVICE constexpr index_t ceil_to_qualified_tile_length()
{
if(len == 96)
if constexpr(Headdim == 48)
return 48;
else if constexpr(Headdim == 96)
return 128;
if(len == 160)
else if constexpr(Headdim == 160)
return 256;
if(len == 192)
else if constexpr(Headdim == 192)
return 192;
// only length of 96, 160 and power-of-two is supported
if(!(len & (len - 1)))
return len;
return 0;
else if constexpr(is_power_of_two_integer(Headdim))
return Headdim;
else
static_assert(Headdim == 0,
"only Headdim of 48, 96, 160, 192 and power-of-two is supported");
};
template <typename BlockTile_, // sequence<...
@@ -55,7 +57,7 @@ struct TileFmhaShape
// once (or repeately load Q as a whole tile)
static_assert(kQKHeaddim % kK0 == 0, "kQKHeaddim should be divisible by kK0");
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length(kQKHeaddim);
static constexpr index_t kSubQKHeaddim = ceil_to_qualified_tile_length<kQKHeaddim>();
// v, rowmajor : seqlen*hdim, colmajor : hdim*seqlen
static constexpr bool IsVLayoutRowMajor = IsVLayoutRowMajor_;

View File

@@ -12,6 +12,24 @@
namespace ck_tile {
// fp32
// Single-instruction fp32 warp GEMM: one 16x16x4 MFMA per invocation.
using WarpGemmMfmaF32F32F32M16N16K4 = WarpGemmImpl<
WarpGemmAttributeMfma<WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>>>;
// K = 16 fp32 warp GEMM built by iterating the native 16x16x4 MFMA 4 times
// along the K dimension.
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16 = WarpGemmImpl<WarpGemmAttributeMfmaIterateK<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
// Same K-iterated fp32 warp GEMM, but with the C-tile distribution transposed.
template <WGAttrNumAccessEnum AttrNumAccess = WGAttrNumAccessEnum::Single>
using WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution =
WarpGemmImpl<WarpGemmAttributeMfmaIterateKAndTransposedCDistribution<
WarpGemmAttributeMfmaImplF32F32F32M16N16K4<WGAttrCtlEnum::Default_>,
4,
AttrNumAccess>>;
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -61,6 +61,135 @@ enum class WGAttrCtlEnum
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
// F32
// Warp-level GEMM attribute for the gfx9 MFMA instruction
// v_mfma_f32_16x16x4f32: C(16x16) += A(16x4) * B(4x16), all operands fp32,
// computed by a single wavefront.
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF32F32F32M16N16K4
{
// Selects the register-constraint flavor used by the inline-asm dispatch macro.
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = float;
using BDataType = float;
using CDataType = float;
// Per-lane fragments: one fp32 element of A and of B, four fp32 accumulators.
using AVecType = ext_vector_t<ADataType, 1>;
using BVecType = ext_vector_t<BDataType, 1>;
using CVecType = ext_vector_t<CDataType, 4>;
// Tile shape produced by one instruction.
static constexpr index_t kM = 16;
static constexpr index_t kN = 16;
static constexpr index_t kK = 4;
// Lane-to-tile mapping constants consumed by the surrounding warp-gemm
// framework (A/B: 16 lanes along M/N x 4 lanes along K, 1 element per lane;
// C: 4x16 lane grid with 4 accumulators per lane).
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 16;
static constexpr index_t kBNLane = 16;
static constexpr index_t kABKLane = 4;
static constexpr index_t kABKPerLane = 1;
static constexpr index_t kCMLane = 4;
static constexpr index_t kCNLane = 16;
static constexpr index_t kCM0PerLane = 1;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// Inline-asm path when Ctrl requests explicit register constraints;
// the macro forms the `if` whose `else` is the builtin fallback below.
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x4f32", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
#else
// No MFMA on non-gfx9 targets; consume the operands to avoid warnings.
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
// Zero-initialized accumulator makes the result a pure product.
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_16x16x4f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// Warp-level GEMM attribute for the gfx9 MFMA instruction
// v_mfma_f32_32x32x2f32: C(32x32) += A(32x2) * B(2x32), all operands fp32,
// computed by a single wavefront.
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF32F32F32M32N32K2
{
// Selects the register-constraint flavor used by the inline-asm dispatch macro.
static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
using ADataType = float;
using BDataType = float;
using CDataType = float;
// Per-lane fragments: one fp32 element of A and of B, sixteen fp32 accumulators.
using AVecType = ext_vector_t<ADataType, 1>;
using BVecType = ext_vector_t<BDataType, 1>;
using CVecType = ext_vector_t<CDataType, 16>;
// Tile shape produced by one instruction.
static constexpr index_t kM = 32;
static constexpr index_t kN = 32;
static constexpr index_t kK = 2;
// Lane-to-tile mapping constants consumed by the surrounding warp-gemm
// framework (A/B: 32 lanes along M/N x 2 lanes along K, 1 element per lane;
// C: 2x32 lane grid with 4x4 accumulators per lane).
static constexpr index_t kAMBlock = 1;
static constexpr index_t kBNBlock = 1;
static constexpr index_t kAMLane = 32;
static constexpr index_t kBNLane = 32;
static constexpr index_t kABKLane = 2;
static constexpr index_t kABKPerLane = 1;
static constexpr index_t kCMLane = 2;
static constexpr index_t kCNLane = 32;
static constexpr index_t kCM0PerLane = 4;
static constexpr index_t kCM1PerLane = 4;
// c_vec += a_vec * b_vec
template <bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
bool_constant<post_nop_> = {}) const
{
// Inline-asm path when Ctrl requests explicit register constraints;
// the macro forms the `if` whose `else` is the builtin fallback below.
DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x2f32", Ctrl)
else
{
#if defined(__gfx9__)
c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], c_vec, 0, 0, 0);
#else
// No MFMA on non-gfx9 targets; consume the operands to avoid warnings.
ck_tile::ignore = c_vec;
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
#endif
}
}
// c_vec = a_vec * b_vec
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
{
#if defined(__gfx9__)
// Zero-initialized accumulator makes the result a pure product.
return bit_cast<CVecType>(
__builtin_amdgcn_mfma_f32_32x32x2f32(a_vec[0], b_vec[0], CVecType{0.f}, 0, 0, 0));
#else
ck_tile::ignore = a_vec;
ck_tile::ignore = b_vec;
return CVecType{0.f};
#endif
}
};
// V_MFMA_F32_16x16x32_BF16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K32

View File

@@ -23,6 +23,11 @@ template <typename AType,
struct WarpGemmDispatcher;
// clang-format off
// fp32 (MFMA)
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmDispatcher<float, float, float, 16, 16, 4, false> { using Type = WarpGemmMfmaF32F32F32M16N16K4; };
// K = 16 variants iterate the native 16x16x4 MFMA along K.
template<> struct WarpGemmDispatcher<float, float, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF32F32F32M16N16K16<>; };
template<> struct WarpGemmDispatcher<float, float, float, 16, 16, 16, true> { using Type = WarpGemmMfmaF32F32F32M16N16K16TransposedCDistribution<>; };
// fp16
// ADataType, BDataType, AccDataType, MPerWave, NPerWave, KPerWave, TransposeC, SwizzleA, UseStructuredSparsity
template<> struct WarpGemmDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };

View File

@@ -38,8 +38,10 @@ set(REGRESSION_TESTS
test_conv_tensor_rearrange
test_gemm_mx
test_ck_tile_batched_transpose
test_ck_tile_fmha_bwd_fp32
test_ck_tile_fmha_bwd_bf16
test_ck_tile_fmha_bwd_fp16
test_ck_tile_fmha_fwd_fp32
test_ck_tile_fmha_fwd_bf16
test_ck_tile_fmha_fwd_fp16
test_ck_tile_fmha_fwd_fp8

View File

@@ -6,12 +6,18 @@ endif()
set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
add_gtest_executable(test_ck_tile_fmha_bwd_fp32 test_fmha_bwd_fp32.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_fp32 PRIVATE ${FMHA_BWD_INSTANCES})
add_gtest_executable(test_ck_tile_fmha_bwd_bf16 test_fmha_bwd_bf16.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_bf16 PRIVATE ${FMHA_BWD_INSTANCES})
add_gtest_executable(test_ck_tile_fmha_bwd_fp16 test_fmha_bwd_fp16.cpp)
target_link_libraries(test_ck_tile_fmha_bwd_fp16 PRIVATE ${FMHA_BWD_INSTANCES})
add_gtest_executable(test_ck_tile_fmha_fwd_fp32 test_fmha_fwd_fp32.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_fp32 PRIVATE ${FMHA_FWD_INSTANCES})
add_gtest_executable(test_ck_tile_fmha_fwd_bf16 test_fmha_fwd_bf16.cpp)
target_link_libraries(test_ck_tile_fmha_fwd_bf16 PRIVATE ${FMHA_FWD_INSTANCES})
@@ -23,8 +29,10 @@ target_link_libraries(test_ck_tile_fmha_fwd_fp8 PRIVATE ${FMHA_FWD_INSTANCES})
add_custom_target(test_ck_tile_fmha
DEPENDS
test_ck_tile_fmha_bwd_fp32
test_ck_tile_fmha_bwd_bf16
test_ck_tile_fmha_bwd_fp16
test_ck_tile_fmha_fwd_fp32
test_ck_tile_fmha_fwd_bf16
test_ck_tile_fmha_fwd_fp16
test_ck_tile_fmha_fwd_fp8

View File

@@ -0,0 +1,20 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "example/ck_tile/01_fmha/fmha_bwd.hpp"
#include "example/ck_tile/01_fmha/fmha_bwd_runner.hpp"
#include "gtest/gtest.h"
using DataTypeConfig = FmhaBwdFp32;
using ::testing::Values;
using ::testing::ValuesIn;
const auto HDimValues = Values(std::tuple{32, -1}, std::tuple{64, -1}, std::tuple{128, -1});
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
constexpr std::string init_method = "uf";
#include "test_fmha_bwd.inc"

View File

@@ -515,6 +515,8 @@ class PagedKV : public TestWithParam<std::tuple<std::tuple<int, int>,
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(PagedKV);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
PagedKV,
Combine(SplitKVHDimValues,
@@ -580,6 +582,8 @@ class SplitKV : public TestWithParam<std::tuple<std::tuple<int, int>,
{
};
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(SplitKV);
INSTANTIATE_TEST_SUITE_P(TestCkTileFmhaFwd,
SplitKV,
Combine(SplitKVHDimValues,
@@ -662,6 +666,8 @@ INSTANTIATE_TEST_SUITE_P(
std::tuple{2, 3, 1, 264, 265, "1"},
std::tuple{4, 4, 2, 71, 64, "1"})));
GTEST_ALLOW_UNINSTANTIATED_PARAMETERIZED_TEST(AppendKV);
TEST_P(AppendKV, Test)
{
auto [hdims,

View File

@@ -0,0 +1,39 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "example/ck_tile/01_fmha/fmha_fwd.hpp"
#include "example/ck_tile/01_fmha/fmha_fwd_runner.hpp"
#include "gtest/gtest.h"
#include <tuple>
#include <string>
using ::testing::Values;
using DataTypeConfig = FmhaFwdFp32;
const auto HDimValues = Values(std::tuple{32, -1},
std::tuple{48, -1},
std::tuple{64, -1},
std::tuple{96, 128},
std::tuple{128, -1},
std::tuple{192, -1},
std::tuple{256, -1});
const auto SplitKVHDimValues = Values();
const auto AppendKVHDimValues = Values();
const auto ModeValues = Values(mode_enum::batch, mode_enum::group);
const auto IsVRowmajorValues = Values(true);
const bool squant = false;
const std::string init_method = "uf";
const bool def_lse = true;
const bool def_is_v_rowmajor = true;
// fp32 forward tests run with the requested sequence length unmodified.
int adjust_seqlen(int requested)
{
    return requested;
}
#include "test_fmha_fwd.inc"