mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-06 15:54:31 +00:00
Add ops
This commit is contained in:
@@ -543,6 +543,40 @@ struct intrin_mfma_scale_f32_32x32x64f8f6f4<32, 32>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f4x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_32x32x64_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float16_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
@@ -579,6 +613,39 @@ struct intrin_mfma_scale_f32_16x16x128f8f6f4<16, 16>
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const f4x32_t& reg_a,
|
||||
const int32_t scale_a,
|
||||
const f4x32_t& reg_b,
|
||||
const int32_t scale_b,
|
||||
FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
int32x4_t arg_a = bit_cast<int32x4_t>(reg_a);
|
||||
int32x4_t arg_b = bit_cast<int32x4_t>(reg_b);
|
||||
|
||||
using arg_type = int32x8_t;
|
||||
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_mfma_scale_f32_16x16x128_f8f6f4(
|
||||
arg_type{arg_a[0], arg_a[1], arg_a[2], arg_a[3], 0, 0, 0, 0},
|
||||
arg_type{arg_b[0], arg_b[1], arg_b[2], arg_b[3], 0, 0, 0, 0},
|
||||
reg_c.template AsType<float4_t>()[Number<0>{}],
|
||||
4, // cbsz
|
||||
4, // blgp
|
||||
0, // OPSEL
|
||||
scale_a,
|
||||
0, // OPSEL
|
||||
scale_b);
|
||||
#else
|
||||
ignore = reg_a;
|
||||
ignore = scale_a;
|
||||
ignore = reg_b;
|
||||
ignore = scale_b;
|
||||
ignore = reg_c;
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user