Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-14 10:09:41 +00:00)
Fix arch limitation bug (#639)
[ROCm/composable_kernel commit: ea028ac65a]
This commit is contained in:
@@ -25,7 +25,7 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
|
||||
// delete them.
|
||||
// amd_assembly_wmma_f32_16x16x16_f16_w32(
|
||||
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<float8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32(
|
||||
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
|
||||
#else
|
||||
@@ -46,7 +46,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w32<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<float8_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32(
|
||||
reg_a, reg_b, reg_c.template AsType<float8_t>()[Number<0>{}]);
|
||||
@@ -71,7 +71,7 @@ struct intrin_wmma_f16_16x16x16_f16_w32<16, 16, Opsel>
|
||||
// opsel usage
|
||||
// false: D0.[0:15] = result
|
||||
// true : D0.[16:31]= result
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<half16_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w32(
|
||||
reg_a, reg_b, reg_c.template AsType<half16_t>()[Number<0>{}], Opsel);
|
||||
#else
|
||||
@@ -95,7 +95,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, Opsel>
|
||||
// opsel usage
|
||||
// false: D0.[0:15] = result
|
||||
// true : D0.[16:31]= result
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<bhalf16_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32(
|
||||
reg_a, reg_b, reg_c.template AsType<bhalf16_t>()[Number<0>{}], Opsel);
|
||||
@@ -117,7 +117,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w32<16, 16, neg_a, neg_b, clamp>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<int32x8_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32(
|
||||
neg_a,
|
||||
@@ -145,7 +145,7 @@ struct intrin_wmma_f32_16x16x16_f16_w64<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half16_t& reg_a, const half16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f32_16x16x16_f16_w64(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
|
||||
#else
|
||||
@@ -166,7 +166,7 @@ struct intrin_wmma_f32_16x16x16_bf16_w64<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const bhalf16_t& reg_a, const bhalf16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}]);
|
||||
@@ -191,7 +191,7 @@ struct intrin_wmma_f16_16x16x16_f16_w64<16, 16, Opsel>
|
||||
// opsel usage
|
||||
// false: D0.[0:15] = result
|
||||
// true : D0.[16:31]= result
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<half8_t>()(Number<0>{}) = __builtin_amdgcn_wmma_f16_16x16x16_f16_w64(
|
||||
reg_a, reg_b, reg_c.template AsType<half8_t>()[Number<0>{}], Opsel);
|
||||
#else
|
||||
@@ -215,7 +215,7 @@ struct intrin_wmma_bf16_16x16x16_bf16_w64<16, 16, Opsel>
|
||||
// opsel usage
|
||||
// false: D0.[0:15] = result
|
||||
// true : D0.[16:31]= result
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<bhalf8_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64(
|
||||
reg_a, reg_b, reg_c.template AsType<bhalf8_t>()[Number<0>{}], Opsel);
|
||||
@@ -237,7 +237,7 @@ struct intrin_wmma_i32_16x16x16_iu8_w64<16, 16, neg_a, neg_b, clamp>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x16_t& reg_a, const int8x16_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
#if defined(__gfx11__)
|
||||
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__)
|
||||
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
|
||||
__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64(
|
||||
neg_a,
|
||||
|
||||
Reference in New Issue
Block a user