Add non scaled ops

This commit is contained in:
Rostyslav Geyyer
2025-03-12 16:03:23 +00:00
parent d466096e25
commit a2200a4e42

View File

@@ -507,6 +507,34 @@ struct intrin_mfma_f32_32x32x64f8f6f4<32, 32>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float16_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float16_t>()[Number<0>{}],
4, // cbsz
4, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};
@@ -682,6 +710,33 @@ struct intrin_mfma_f32_16x16x128f8f6f4<16, 16>
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
template <class FloatC>
__device__ static void Run(const f4x32_t& reg_a, const f4x32_t& reg_b, FloatC& reg_c)
{
#if defined(__gfx950__)
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
using arg_type = int32x8_t;
reg_c.template AsType<float4_t>()(Number<0>{}) =
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
reg_c.template AsType<float4_t>()[Number<0>{}],
4, // cbsz
4, // blgp
0, // OPSEL
0,
0, // OPSEL
0);
#else
ignore = reg_a;
ignore = reg_b;
ignore = reg_c;
#endif
}
};