Remove CK_USE_AMD_MFMA_GFX950 (#1935)

* Add runtime check in example_gemm_xdl_streamk for gfx950

* Add runtime check in grouped conv fwd examples for gfx950

* Disable CK_USE_AMD_MFMA_GFX950

* Add new instances for gfx950

* Fix test_gemm_universal on gfx950
This commit is contained in:
jefyang1
2025-03-04 10:32:25 -08:00
committed by GitHub
parent 540a6da40b
commit c95bda93ba
186 changed files with 3272 additions and 883 deletions

View File

@@ -1053,40 +1053,49 @@ struct MfmaSelector
#endif
}
template <>
constexpr auto GetMfma<int8_t, 32, 32, int8_t, false>()
{
#if defined(__gfx950__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x32i8;
}
template <>
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x64i8;
}
#elif defined(__gfx942__)
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x16i8;
}
template <>
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x32i8;
}
#else
template <>
constexpr auto GetMfma<int8_t, 32, 32>()
{
return MfmaInstr::mfma_i32_32x32x8i8;
}
template <>
constexpr auto GetMfma<int8_t, 16, 16>()
{
return MfmaInstr::mfma_i32_16x16x16i8;
}
#endif
}
// Specialization for int8_t inputs on a 32x32 tile with the last template argument set to true.
// NOTE(review): that bool presumably mirrors the single-rate flag XdlopsGemm hands to
// MfmaSelector -- confirm against the primary template, which is not visible here.
// gfx942 and gfx950 resolve to the same instruction; every other target uses the x8 variant.
template <>
constexpr auto GetMfma<int8_t, 32, 32, int8_t, true>()
{
#if defined(__gfx942__) || defined(__gfx950__)
return MfmaInstr::mfma_i32_32x32x16i8;
#else
// Neither gfx942 nor gfx950.
return MfmaInstr::mfma_i32_32x32x8i8;
#endif
}
// Specialization for int8_t inputs on a 16x16 tile with the last template argument set to false.
// NOTE(review): that bool presumably mirrors the single-rate flag XdlopsGemm hands to
// MfmaSelector -- confirm against the primary template, which is not visible here.
// The gfx950 check comes first, so it takes priority over the gfx942 branch.
template <>
constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
{
#if defined(__gfx950__)
return MfmaInstr::mfma_i32_16x16x64i8;
#elif defined(__gfx942__)
return MfmaInstr::mfma_i32_16x16x32i8;
#else
// Neither gfx950 nor gfx942.
return MfmaInstr::mfma_i32_16x16x16i8;
#endif
}
// Specialization for int8_t inputs on a 16x16 tile with the last template argument set to true.
// NOTE(review): that bool presumably mirrors the single-rate flag XdlopsGemm hands to
// MfmaSelector -- confirm against the primary template, which is not visible here.
// Unlike the `false` variant, gfx950 gets the same instruction as gfx942 here.
template <>
constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
{
#if defined(__gfx942__) || defined(__gfx950__)
return MfmaInstr::mfma_i32_16x16x32i8;
#else
// Neither gfx942 nor gfx950.
return MfmaInstr::mfma_i32_16x16x16i8;
#endif
}
template <>
constexpr auto GetMfma<f8_t, 32, 32>()
@@ -1440,12 +1449,13 @@ struct XdlopsGemm
}
// Falls back to single rate instruction on gfx950 if KPack <= 4 (half_t/bhalf_t) or
// KPack <= 8 (int8_t); no change on gfx942-
// Selector instance for this GEMM configuration. The last argument is true only when the
// base type is half_t or bhalf_t and KPack <= 4.
static constexpr auto mfma =
    MfmaSelector<base_type,
                 MPerXdlops,
                 NPerXdlops,
                 additional_type,
                 ((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
                  KPack <= 4)>{};
// Selector instance for this GEMM configuration. The last argument is true when the base
// type is half_t or bhalf_t with KPack <= 4, or when it is int8_t with KPack <= 8.
static constexpr auto mfma =
    MfmaSelector<base_type,
                 MPerXdlops,
                 NPerXdlops,
                 additional_type,
                 (((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) &&
                   KPack <= 4) ||
                  (is_same<base_type, int8_t>::value && KPack <= 8))>{};
// Instruction descriptor exposed by the selector instance above.
static constexpr auto mfma_instr = mfma.selected_mfma;