mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Remove CK_USE_AMD_MFMA_GFX950 (#1935)
* Add runtime check in example_gemm_xdl_streamk for gfx950
* Add runtime check in grouped conv fwd examples for gfx950
* Disable CK_USE_AMD_MFMA_GFX950
* Add new instances for gfx950
* Fix test_gemm_universal on gfx950
This commit is contained in:
@@ -1053,40 +1053,49 @@ struct MfmaSelector
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 32, 32, int8_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x32i8;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x64i8;
|
||||
}
|
||||
#elif defined(__gfx942__)
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x16i8;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
}
|
||||
#else
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x8i8;
|
||||
}
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 32, 32, int8_t, true>()
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return MfmaInstr::mfma_i32_32x32x16i8;
|
||||
#else
|
||||
return MfmaInstr::mfma_i32_32x32x8i8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16, int8_t, false>()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
return MfmaInstr::mfma_i32_16x16x64i8;
|
||||
#elif defined(__gfx942__)
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
#else
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<int8_t, 16, 16, int8_t, true>()
|
||||
{
|
||||
#if defined(__gfx942__) || defined(__gfx950__)
|
||||
return MfmaInstr::mfma_i32_16x16x32i8;
|
||||
#else
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
constexpr auto GetMfma<f8_t, 32, 32>()
|
||||
@@ -1440,12 +1449,13 @@ struct XdlopsGemm
|
||||
}
|
||||
|
||||
// Falls back to the single-rate instruction on gfx950 when KPack <= 4 (half_t / bhalf_t) or KPack <= 8 (int8_t); no change on gfx942 and earlier.
|
||||
static constexpr auto
|
||||
mfma = MfmaSelector < base_type,
|
||||
MPerXdlops, NPerXdlops, additional_type,
|
||||
((is_same<base_type, half_t>::value || is_same<base_type, bhalf_t>::value) && KPack <= 4)
|
||||
? true
|
||||
: false > {};
|
||||
static constexpr auto mfma = MfmaSelector < base_type, MPerXdlops, NPerXdlops, additional_type,
|
||||
(((is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, bhalf_t>::value) &&
|
||||
KPack <= 4) ||
|
||||
(is_same<base_type, int8_t>::value && KPack <= 8))
|
||||
? true
|
||||
: false > {};
|
||||
|
||||
static constexpr auto mfma_instr = mfma.selected_mfma;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user