mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[FMHA FWD] gfx950 Accuracy enhancement & bug fix (#2900)
* disable cast_tile_pk_fp16_fp32 on gfx950 * fix wrong encoding when hdim is not exponentiation of 2 --------- Co-authored-by: asleepzzz <hanwen.chang@amd.com>
This commit is contained in:
@@ -231,7 +231,7 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
|
||||
template <typename OutDataType, typename InTensor>
|
||||
CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
|
||||
{
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
|
||||
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
|
||||
// This API is designed to use the _pk_ serious of function
|
||||
constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user