mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Replace LLVM intrinsics with Clang builtins (#65)
* test mfma builtins
* add fp16 builtins
* add int8 builtins
* add bf16 builtins
* simplify host conv forward
* clean
* clean
[ROCm/composable_kernel commit: 6d92959ad3]
This commit is contained in:
@@ -5,77 +5,6 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A, B, C, cbsz, abid, blgp
|
||||
// fp32
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
|
||||
|
||||
// fp16
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
|
||||
|
||||
// bfp16
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
|
||||
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
|
||||
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
|
||||
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
|
||||
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
|
||||
|
||||
// int8
|
||||
extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
|
||||
int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
|
||||
|
||||
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
|
||||
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
|
||||
|
||||
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
|
||||
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
|
||||
|
||||
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
|
||||
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
|
||||
|
||||
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
|
||||
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
|
||||
|
||||
// fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
@@ -86,9 +15,9 @@ struct intrin_mfma_f32_32x32x1f32<64, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
@@ -99,7 +28,7 @@ struct intrin_mfma_f32_32x32x1f32<32, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -113,7 +42,7 @@ struct intrin_mfma_f32_32x32x2f32<32, 32>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x2f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -127,7 +56,7 @@ struct intrin_mfma_f32_16x16x4f32<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -141,8 +70,7 @@ struct intrin_mfma_f32_16x16x1f32<16, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -156,7 +84,7 @@ struct intrin_mfma_f32_4x4x1f32<4, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -167,9 +95,9 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x1f32(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
@@ -184,9 +112,9 @@ struct intrin_mfma_f32_32x32x4f16<64, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_c.template AsType<float32_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<1>{}], 1, 1, 0);
|
||||
}
|
||||
};
|
||||
@@ -197,7 +125,7 @@ struct intrin_mfma_f32_32x32x4f16<32, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_c.template AsType<float32_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float32_t>()[Number<0>{}], 1, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -211,7 +139,7 @@ struct intrin_mfma_f32_32x32x8f16<32, 32>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -225,7 +153,7 @@ struct intrin_mfma_f32_16x16x16f16<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -239,7 +167,7 @@ struct intrin_mfma_f32_16x16x4f16<16, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 2, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -253,7 +181,7 @@ struct intrin_mfma_f32_4x4x4f16<4, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -264,9 +192,9 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 4, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_c.template AsType<float4_t>()(Number<1>{}) = __builtin_amdgcn_mfma_f32_4x4x4f16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<1>{}], 4, 1, 0);
|
||||
}
|
||||
};
|
||||
@@ -281,9 +209,8 @@ struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -296,9 +223,8 @@ struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -311,7 +237,7 @@ struct intrin_mfma_f32_32x32x4bf16<32, 32>
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -325,7 +251,7 @@ struct intrin_mfma_f32_16x16x8bf16<16, 16>
|
||||
// Run: issue one 16x16x8 bf16 MFMA and accumulate into reg_c.
//   reg_a, reg_b : ushort2_t operands (A and B).
//   reg_c        : accumulator, accessed through its float4_t view at index 0;
//                  its current value is passed in as C and overwritten with
//                  the result.
// The trailing 0, 0, 0 are the cbsz, abid and blgp arguments.
template <class FloatC>
|
||||
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
// Must be the 16x16x8bf16 builtin: it works on a float4_t accumulator, whereas
// __builtin_amdgcn_mfma_f32_32x32x4bf16 takes and returns a 16-float accumulator.
reg_c.template AsType<float4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
@@ -340,12 +266,12 @@ struct intrin_mfma_i32_32x32x8i8<32, 32>
|
||||
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
|
||||
bit_cast<int>(reg_b),
|
||||
reg_c.template AsType<int32x16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
__builtin_amdgcn_mfma_i32_32x32x8i8(bit_cast<int>(reg_a),
|
||||
bit_cast<int>(reg_b),
|
||||
reg_c.template AsType<int32x16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -359,12 +285,12 @@ struct intrin_mfma_i32_16x16x16i8<16, 16>
|
||||
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
|
||||
bit_cast<int>(reg_b),
|
||||
reg_c.template AsType<int32x4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
__builtin_amdgcn_mfma_i32_16x16x16i8(bit_cast<int>(reg_a),
|
||||
bit_cast<int>(reg_b),
|
||||
reg_c.template AsType<int32x4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -169,6 +169,8 @@ struct DynamicBuffer
|
||||
is_same<remove_cvref_t<X>, int8x2_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value) ||
|
||||
(is_same<remove_cvref_t<T>, int8x8_t>::value &&
|
||||
@@ -202,6 +204,14 @@ struct DynamicBuffer
|
||||
*c_style_pointer_cast<int32_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x8_t>::value)
|
||||
{
|
||||
// HACK: cast pointer of x is bad
|
||||
// TODO: remove this after compiler fix
|
||||
*c_style_pointer_cast<int32x2_t*>(&p_data_[i]) =
|
||||
*c_style_pointer_cast<const int32x2_t*>(&x);
|
||||
}
|
||||
else if constexpr(is_same<remove_cvref_t<T>, int8x4_t>::value &&
|
||||
is_same<remove_cvref_t<X>, int8x4_t>::value)
|
||||
{
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
template <ck::index_t BlockSize,
|
||||
typename FloatAB,
|
||||
@@ -70,6 +71,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
|
||||
using ElementwiseOperation = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
using GridwiseGemm =
|
||||
GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
|
||||
FloatAB,
|
||||
@@ -79,6 +82,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K,
|
||||
CMNGridDesc,
|
||||
ElementwiseOperation,
|
||||
ElementwiseOperation,
|
||||
ElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
@@ -87,7 +93,6 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -95,7 +100,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
@@ -103,17 +108,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
AGridStepHacks,
|
||||
BGridStepHacks,
|
||||
CGridStepHacks,
|
||||
AGridMoveSliceWindowStepHacks,
|
||||
BGridMoveSliceWindowStepHacks,
|
||||
CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
{
|
||||
std::cout << "a_grid_desc_k0_m_k1{" << a_grid_desc_k0_m_k1.GetLength(I0) << ", "
|
||||
@@ -152,6 +150,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
auto element_op_ = ElementwiseOperation{};
|
||||
|
||||
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
@@ -162,6 +162,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
remove_reference_t<AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<BGridDesc_K0_N_K>,
|
||||
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
ElementwiseOperation,
|
||||
ElementwiseOperation,
|
||||
ElementwiseOperation,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
@@ -176,6 +179,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
element_op_,
|
||||
element_op_,
|
||||
element_op_,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
else
|
||||
@@ -187,6 +193,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
remove_reference_t<AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<BGridDesc_K0_N_K>,
|
||||
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
ElementwiseOperation,
|
||||
ElementwiseOperation,
|
||||
ElementwiseOperation,
|
||||
remove_reference_t<Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
@@ -201,6 +210,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
|
||||
a_grid_desc_k0_m_k1,
|
||||
b_grid_desc_k0_n_k1,
|
||||
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
|
||||
element_op_,
|
||||
element_op_,
|
||||
element_op_,
|
||||
block_2_ctile_map);
|
||||
}
|
||||
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
|
||||
|
||||
Reference in New Issue
Block a user