[CK_TILE] Support f32 in FMHA (fwd and bwd) (#2836)

* Support 16x16 (MFMA, WMMA) and 32x32 (MFMA) tiles in fwd and bwd BlockDropout

Add comments with dropout implementation details

Fix performance regression of fwd+dropout

    * Remove some usage of type punning (reinterpret_cast with ref or ptr) in Philox;
    * "scalarize" seed and offset, they may come either from kernel args or from device memory
      (presumably loaded with vector loads).

    These changes help the compiler to produce more optimal code and reduce register spilling.

Use WarpGemmDispatcher instead of explicit WarpGemmMfma... to get CWarpDstrEncoding

Use code based on BlockDropout in BlockDropoutBwd

Refactor BlockDropout (fwd)

Implement BlockDropout (fwd) for WMMA

    Originally BlockDropout only supported 32x32 tiles (IsWG32 = true),
    this version supports 16x16 tiles.
    If MPerBlock > MWarp * 16, it can generate numbers for two 16x16 tiles, similarly
    to BlockDropoutBwd.

Implement BlockDropoutBwd for WMMA

Remove MakeRandValLds* functions unused in BlockDropoutBwd

Remove unused Run overload from BlockDropoutBwd

* Fix regression with philox seed and offset when they exceed 32-bit int

__builtin_amdgcn_readfirstlane works with 32-bit values, seed and offset
are 64-bit so they get truncated.

* Add F32 MFMA warp gemms

* Support f32 in fwd FMHA

* Implement transpose_vectors for 4-byte types (float)

* Fix unexpected implicit f32->uint32 cast in buffer_store<4>

__builtin_amdgcn_raw_buffer_store_b32 expects unsigned int but float was passed (implicitly cast to uint).
mbuf_t types in other buffer_store<> are changed for consistency.

* Support F32 in bwd FMHA

hdim = 256 is disabled for now because it uses too much memory on gfx90a

* Support Headdim = 48 (divisible by 16) in fwd

* Add fp32-specific receipts (800 and 801)

* Tune fwd tiles

* Tune bwd tiles

* Use small tiles only for small seqlen_q

* Fix after rebasing

* Fix selection of a fallback tile based on bm0

The assumption that the largest bm0 == 128 is not always true for
current fp32 tiles.

* Remove constraints and adjust filtering for fp32

Custom constraints are no longer needed because now the smallest tile
is selected automatically based on seqlen_q.
Filters related to qr_async_trload disabled valid fp32 tiles.

* Add fp32 tests

* Make splitkv and appendkv compile for fp32 only

There are no instances yet, but API still must compile when only fp32 is
requested.

* Remove unimportant f32 instances

* Add test_ck_tile_fmha_*_fp32 to REGRESSION_TESTS

* Replace magic numbers with a constant, improve comments for dropout

* Update changelog

* Fix condition that dq_acc must be set to zero when mask is used

The change was introduced in #2799

* Replace warp_uniform with recently added amd_wave_read_first_lane

* Add hdim = 96 and 192 to fwd
This commit is contained in:
Anton Gorenko
2025-09-27 19:03:48 +06:00
committed by GitHub
parent c6bfd97c2d
commit 1edd250115
31 changed files with 922 additions and 488 deletions

View File

@@ -1,8 +1,9 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
# generate kernel instances to speed up compilation
FWD_DTYPE_MAP = {
"fp32" : "FmhaFwdFp32",
"fp16" : "FmhaFwdFp16",
"bf16" : "FmhaFwdBf16",
"fp8" : "FmhaFwdFp8",
@@ -12,6 +13,7 @@ FWD_DTYPE_MAP = {
}
BWD_DTYPE_MAP = {
"fp32": "FmhaBwdFp32",
"fp16": "FmhaBwdFp16",
"bf16": "FmhaBwdBf16"
}

View File

@@ -601,6 +601,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -370,7 +370,14 @@ class FmhaBwdDQDKDVKernel:
# TODO: design a more practical way to do it
# this is current supported tile size.
def get_dq_dk_dv_tiles(dtype : str, tr_load: str) -> List[FmhaBwdDQDKDVTileSize]:
if (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
if dtype == 'fp32' and tr_load == 'f':
return [
# bm0, bn0, bk0, bk1, bk2, bk3, bk4, bhdq, bhdv,
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 64, 16, 64, 16, 16, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 16, 64, 128, 16, 128, 16, 16, 128, 128, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 16, 16, 16, 16, 1),
]
elif (dtype == 'fp16' or dtype == 'bf16') and tr_load == 'f':
return [
FmhaBwdDQDKDVTileSize( 32, 128, 32, 32, 32, 32, 64, 32, 32, 1, 4, 1, 4, 1, 1, 2, 2, 1, 16, 16, 32, 16, 16, 16, 1),
FmhaBwdDQDKDVTileSize( 32, 128, 64, 32, 64, 32, 32, 64, 64, 1, 4, 1, 4, 1, 1, 1, 4, 1, 16, 16, 32, 16, 16, 16, 1),
@@ -865,6 +872,30 @@ def get_bwd_blobs(filter_list: str, receipt, mask_impl, optdim_list) -> Tuple[Fm
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
# fp32 only, all variations
if receipt == 800:
cond = dtype == 'fp32'
cond &= dpad == dvpad
if not cond:
continue
# fp32 only, minimal set of parameters
elif receipt == 801:
cond = dtype == 'fp32'
cond &= hdim in [64, 128]
cond &= dpad == dvpad
cond &= mode == 'batch'
cond &= bias == 'no'
cond &= dropout == 'no'
cond &= mask == 's_no'
cond &= deterministic == "f"
if not cond:
continue
else:
# Don't build fp32 by default
if dtype == 'fp32':
continue
gen_dot_do_o[t.dot_do_o_kernel] = True
gen_dq_dk_dv[t.dq_dk_dv_kernel] = True
if not t.convert_dq_kernel.disabled:

View File

@@ -25,6 +25,7 @@ DTYPE_BITS = {
K0_MAX_SUBMAX_MAP = {
32 : 32,
48 : 48,
64 : 64,
96 : 128,
128: 128,
@@ -164,7 +165,7 @@ float fmha_fwd(fmha_fwd_traits t, fmha_fwd_args a, const ck_tile::stream_config&
return get_num_thread_blocks(a.batch, a.nhead_q, a.max_seqlen_q, kM0);
}};
const bool has_load_tr = ck_tile::is_load_tr_supported();
[[maybe_unused]] const bool has_load_tr = ck_tile::is_load_tr_supported();
{F_dispatch}
return r;
@@ -249,9 +250,8 @@ class FmhaFwdApiTrait:
else : return f'a.seqlen_q % {self.bm0} == 0'
else: assert False
@property
def seqtune(self) -> str:
if self.bm0 == 128: return 'true/*fall back to largest tile*/' # group mode only generate spad/skpad == true
def seqtune(self, max_bm0 : int) -> str:
if self.bm0 == max_bm0: return 'true/*fall back to largest tile*/'
else:
return f'a.seqlen_q <= {self.bm0}'
@@ -386,6 +386,7 @@ class FmhaFwdApiPool:
per_hdim_case=str()
for j, (hdim, hdim_v) in enumerate(self.pool[dtype].keys()):
traits=[t for t in self.pool[dtype][(hdim, hdim_v)] if tr_load == t.tr_load]
max_bm0 = max((t.bm0 for t in traits), default=0)
inners=str()
for k, trait in enumerate(traits):
if_k = 'if' if k == 0 else 'else if'
@@ -393,7 +394,7 @@ class FmhaFwdApiPool:
F_pipeline_enum=PIPELINE_ENUM_MAP[trait.pipeline_tag], F_logits=BOOL_MAP[trait.logits], F_mask=get_mask_map(self.mask_impl)[trait.mask],
F_mask_check=get_mask_check_map(self.mask_impl)[trait.mask], F_bias_check=BIAS_CHECK_MAP[trait.bias], F_bias=BIAS_MAP[trait.bias],
F_lse=BOOL_MAP[trait.lse], F_dropout=BOOL_MAP[trait.dropout], F_skip=BOOL_MAP[trait.skip], F_trload=BOOL_MAP[trait.tr_load],
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune, F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_squant=BOOL_MAP[trait.squant], F_scheck=trait.scheck, F_seqtune=trait.seqtune(max_bm0), F_skcheck=trait.skcheck, F_dcheck=trait.dcheck, F_dvcheck=trait.dvcheck,
F_constraint=trait.constraint,
F_spad=BOOL_MAP[trait.spad], F_skpad=BOOL_MAP[trait.skpad], F_dpad=BOOL_MAP[trait.dpad], F_dvpad=BOOL_MAP[trait.dvpad],
F_bm0=trait.bm0, F_bn0=trait.bn0, F_bk0=trait.bk0, F_bn1=trait.bn1, F_bk1=trait.bk1, F_bk0max=trait.bk0max,
@@ -534,7 +535,20 @@ class KernelComponentFactory:
# this is current supported tile size per hdim
@staticmethod
def get_hdim_tile_size_dict(dtype : str) -> Optional[dict]:
if dtype == 'fp16' or dtype == 'bf16':
if dtype == 'fp32':
return {
# bm0, bn0, bk0, bn1, bk1,
( 32, 32) : [FmhaFwdTileSize( 64, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 48, 48) : [FmhaFwdTileSize( 32, 128, 16, 48, 16, 48, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 16, 48, 32, 48, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 64, 64) : [FmhaFwdTileSize( 64, 64, 32, 64, 32, 64, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
( 96, 128) : [FmhaFwdTileSize(128, 64, 32, 128, 32, 96, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(128, 128) : [FmhaFwdTileSize( 32, 128, 32, 128, 16, 128, 2, 1, 1, 2, 1, 1, 16, 16, 16, 16, 16, 16, -1),
FmhaFwdTileSize(128, 64, 32, 128, 32, 128, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(192, 192) : [FmhaFwdTileSize( 64, 64, 32, 192, 32, 192, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
(256, 256) : [FmhaFwdTileSize( 64, 64, 32, 256, 32, 256, 4, 1, 1, 4, 1, 1, 16, 16, 16, 16, 16, 16, -1)],
}
elif dtype == 'fp16' or dtype == 'bf16':
return {
(32, 32) : [FmhaFwdTileSize(128, 64, 16, 32, 32, 32, 4, 1, 1, 4, 1, 1, 32, 32, 16, 32, 32, 16, -1)],
(64, 64) : [FmhaFwdTileSize(16, 32, 64, 64, 32, 64, 1, 1, 1, 1, 1, 1, 16, 16, 32, 16, 16, 32, -1),
@@ -572,7 +586,13 @@ class KernelComponentFactory:
# TODO: currently for qr pipeline, let 't' padding to appear later!!
# TODO: how to design this more generic?
pipelines = []
if dtype in ['fp16', 'bf16']:
if dtype in ['fp32']:
squant = 'f'
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 't', 'f', 'f', logits, bias, lse, dropout, squant, mask, skip, 'f'))
pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', logits, bias, lse, dropout, squant, mask, skip, 'f'))
elif dtype in ['fp16', 'bf16']:
squant = 'f'
for logits, mask, bias, lse, dropout, skip in itertools.product(["t", "f"], get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"], ["t", "f"], ["t", "f"]):
if hdim == 256 and hdim_v == 256:
@@ -626,6 +646,8 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
continue
#for hdim_str, mode, mask, bias, lse in itertools.product(d.keys(), MODE_MAP.keys(), MASK_MAP.keys(), ["t", "f"], ["t", "f"]):
for ((hdim, hdim_v), tiles), mode in itertools.product(d.items(), MODE_MAP.keys()):
for tile, next_tile in zip(tiles, tiles[1:]):
assert next_tile.F_bm0 >= tile.F_bm0, 'Tiles must be ordered by increasing bm0'
for tile, pipeline in itertools.product(tiles, factory.get_pipelines(dtype, hdim, hdim_v, receipt, mask_impl)):
if mode == "group":
if pipeline.F_spad != 't' or pipeline.F_skpad != 't':
@@ -635,12 +657,13 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
# NOTE: this is used to speedup deepseek prefill case, we don't gen training
if pipeline.F_bias != 'no' or pipeline.F_dropout == 't':
continue
if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
continue
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
continue
if dtype != 'fp32':
if pipeline.tag != 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 != 128) or ((hdim, hdim_v) != (128, 128) and tile.F_bm0 != 128)):
# non qr_async_trload only support km0=128 tile size when hdim is not 128
# non qr_async only support kn0=128 tile size when hdim is 128
continue
if pipeline.tag == 'qr_async_trload' and (((hdim, hdim_v) == (128, 128) and tile.F_bn0 == 128) or ((hdim, hdim_v) not in [(64, 64), (128, 128)])):
continue
# logits_soft_cap is only allowed if no bias
if not ((pipeline.F_logits == 't' and pipeline.F_bias == 'no') or pipeline.F_logits == 'f'):
continue
@@ -710,6 +733,31 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
if not cond:
continue
# fp32 only, all variations
if receipt == 800:
cond = dtype == 'fp32'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
if not cond:
continue
# fp32 only, minimal set of parameters
elif receipt == 801:
cond = dtype == 'fp32'
cond &= hdim in [48, 128]
cond &= mode == 'batch'
cond &= pipeline.F_bias == 'no'
cond &= pipeline.F_lse == 'f'
cond &= pipeline.F_dropout == 'f'
cond &= pipeline.F_skip == 'f'
cond &= pipeline.F_logits == 'f'
cond &= pipeline.F_mask == 's_no'
if not cond:
continue
else:
# Don't build fp32 by default
if dtype == 'fp32':
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -184,6 +184,9 @@ class FmhaFwdAppendKVApiPool:
per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_hdim_v=hdim, F_inner_dispatch=inners)
if_i = 'if' if i == 0 else 'else if'
per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
if not per_dtypes:
# empty string we add some ignore to suppress warning in api
per_dtypes += ' (void)t ; (void)s ; (void)a;'
return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_APPENDKV_API.format(F_dispatch = per_dtypes)
@dataclass
@@ -341,6 +344,13 @@ def get_fwd_appendkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, op
cond &= pipeline.F_vlayout == 'row'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -768,6 +768,13 @@ def get_fwd_splitkv_blobs(kernel_filter : Optional[str], receipt, mask_impl, opt
cond &= pipeline.F_squant == 'f'
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)
@@ -834,6 +841,13 @@ def get_fwd_splitkv_combine_blobs(kernel_filter : Optional[str], receipt, optdim
cond = dtype in ['fp16', 'bf16']
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
gen.append(k)
return gen

View File

@@ -560,6 +560,12 @@ def get_fwd_blobs(kernel_filter : Optional[str], receipt, optdim_list, mask_impl
if not cond:
continue
# fp32 only
if receipt == 800 or receipt == 801:
cond = dtype == 'fp32'
if not cond:
continue
api_pool.register_traits(k.api_trait())
gen.append(k)

View File

@@ -43,7 +43,7 @@ auto create_args(int argc, char* argv[])
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("dbias", "0", "output bias gradient or not")
.insert("prec", "fp16", "data type. fp16 or bf16")
.insert("prec", "fp16", "data type. fp32/fp16/bf16")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -159,7 +159,11 @@ int main(int argc, char* argv[])
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
if(data_type == "fp32")
{
return run<FmhaBwdFp32>(arg_parser) == bwd_result::success ? 0 : -2;
}
else if(data_type == "fp16")
{
return run<FmhaBwdFp16>(arg_parser) == bwd_result::success ? 0 : -2;
}

View File

@@ -67,7 +67,7 @@ auto create_args(int argc, char* argv[])
"n or 0, no bias\n"
"e(lementwise) or 1, elementwise bias with 1*1*s*s. e:1, 1*h*s*s. e:2, b*h*s*s\n"
"a(libi) or 2, alibi with 1*h. a:1, b*h")
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
.insert("prec", "fp16", "data type. fp32/fp16/bf16/fp8/bf8")
.insert("mask",
"0",
"0: no mask, 1: top-left(same as 't'), 2:bottom-right(same as 'b')\n"
@@ -227,7 +227,11 @@ int main(int argc, char* argv[])
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
if(data_type == "fp32")
{
return run<FmhaFwdFp32>(arg_parser) == fwd_result::success ? 0 : -2;
}
else if(data_type == "fp16")
{
return run<FmhaFwdFp16>(arg_parser) == fwd_result::success ? 0 : -2;
}

View File

@@ -35,6 +35,14 @@ auto get_elimit(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaBwdFp32>(ck_tile::index_t /*hdim_q*/, ck_tile::index_t /*hdim_v*/)
{
double rtol = 1e-4;
double atol = 1e-4;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaBwdBf16>(ck_tile::index_t hdim_q, ck_tile::index_t hdim_v)
{
@@ -77,7 +85,9 @@ bwd_result fmha_bwd_run(mode_enum mode,
std::optional<std::string> json = std::nullopt)
{
const std::string data_type = []() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdFp16>)
if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdFp32>)
return "fp32";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdFp16>)
return "fp16";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaBwdBf16>)
return "bf16";
@@ -776,7 +786,7 @@ bwd_result fmha_bwd_run(mode_enum mode,
// non-deterministic kernels use atomic add to write dq
// Some block may be skipped with causal mask and dq are not set to zeros
// In these cases thus we need to zero out it first
if(!deterministic || mask.type == mask_enum::no_mask)
if(!deterministic || mask.type != mask_enum::no_mask)
dq_acc_buf.SetZero();
ck_tile::stream_config stream_config_v{nullptr, true, 0, 0, 1};

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -17,6 +17,10 @@
#include <utility>
#include <variant>
struct FmhaFwdFp32
{
};
struct FmhaFwdFp16
{
};
@@ -48,6 +52,22 @@ struct FmhaFwdFp8Fp32
template <typename DataType>
struct FmhaFwdTypeConfig;
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp32>
{
using QDataType = float;
using KDataType = float;
using VDataType = float;
using BiasDataType = float;
using RandValOutputDataType = uint8_t;
using LSEDataType = float; // data type for lse(logsumexp L_j = max_j + log(l_j))
using SaccDataType = float; // data type for first gemm accumulation
using SMPLComputeDataType = float; // data type for reduction, softmax
using PDataType = float; // data type for A matrix of second gemm
using OaccDataType = float; // data type for second gemm accumulation
using ODataType = float;
};
template <>
struct FmhaFwdTypeConfig<FmhaFwdFp16>
{

View File

@@ -41,6 +41,14 @@ auto get_elimit(std::string /*init_method*/)
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaFwdFp32>(std::string /*init_method*/)
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<FmhaFwdBf16>(std::string /*init_method*/)
{
@@ -180,7 +188,9 @@ fwd_result fmha_fwd_run(mode_enum mode,
std::optional<std::string> json = std::nullopt)
{
const std::string data_type = []() {
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp16>)
if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp32>)
return "fp32";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdFp16>)
return "fp16";
else if constexpr(std::is_same_v<DataTypeConfig, FmhaFwdBf16>)
return "bf16";