[FMHA FWD] gfx950 Accuracy enhancement & bug fix (#2900)

* disable cast_tile_pk_fp16_fp32 on gfx950

* fix wrong encoding when hdim is not a power of 2

---------

Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
Haocong WANG
2025-09-24 00:59:41 +08:00
committed by GitHub
parent 7b16782d7c
commit 959df2a155
2 changed files with 4 additions and 3 deletions

View File

@@ -231,7 +231,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
template <typename OutDataType, typename InTensor>
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
{
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
// This API is designed to use the _pk_ series of functions
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();