mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Compile for gfx908 and gfx90a (#130)
* adding compilation for multiple targets * fix build * clean * update Jekinsfile * update readme * update Jenkins * use ck::half_t instead of ushort for bf16 * rename enum classes * clean * rename * clean
This commit is contained in:
@@ -476,7 +476,7 @@ struct MfmaSelector
|
||||
template <>
|
||||
static constexpr auto GetMfma<bhalf_t, 32, 32>()
|
||||
{
|
||||
#if defined(CK_AMD_GPU_GFX90A)
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x4bf16;
|
||||
@@ -486,7 +486,7 @@ struct MfmaSelector
|
||||
template <>
|
||||
static constexpr auto GetMfma<bhalf_t, 16, 16>()
|
||||
{
|
||||
#if defined(CK_AMD_GPU_GFX90A)
|
||||
#if defined(CK_USE_AMD_MFMA_BF16_1K_OP)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x8bf16;
|
||||
|
||||
Reference in New Issue
Block a user