mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
* create files for xdlops
* working on blockwise_gemm_xdlops
* add KReduction
* add m/n repeats
* add 2x2 pipeline
* added 128x128 wavegemm
* use StaticBuffer of vector_type
* break vector type to blk_size
* add kpack into xdlops_gemm and blockwise_gemm
* abroadcast only
* add fp32 mfma instructions
* adding fp16 mfma
* pack half4_t
* rename kperwave to kpack
* add 32x32x8fp16
* add fp16 mfma
* clean code
* clean code
* V4r4 xdlops kpack (#35)
* add kpack with incorrect results
* bug fix for make_dynamic_naive_tensor_descriptor_aligned_v2
* add 1x1 kernel
* add gridwise_gemm_v2 - single_buffer
* enabled dwordx4 for fp16
Co-authored-by: Chao Liu <chao.liu2@amd.com>
* refactor fwd-v4r4-xdlops
* add v4r4-nhwc-xdlop
* improve some perf of nhwc and nchw by tuning parameters, and change scheduling in gridwise-gemm loop
* tweak scheduling in gridwise gemm
* add v4r3 with a single output copy
* init commit: output with slice win
* adding sliceWin
* add multiple repeats pattern
* starting adding bwd-v4r1-xdlops
* use tuple as SrcBuffer
* adding bwd-data v4r1 nhwc xdlops
* fix bug in make_dynamic_naive_tensor_descriptor_aligned_v2()
* fix bug in host bwd-data conv
* initial implementation of bwd-data v4r1 nhwc xdlops
* add launch bound flags
* enable launch bound
* add m/nrepeat=4
* tweak bwd-data v4r1 nhwc xdlops
* added bwd-data v4r1 nhwc xdlops with output A and weight B
* add fwd-v4r4 nhwc xdlops, A input, B weight, C output
Co-authored-by: Chao Liu <chao.liu2@amd.com>
[ROCm/composable_kernel commit: 3835318cc3]
500 lines
17 KiB
C++
500 lines
17 KiB
C++
#ifndef CK_AMD_XDLOPS_HPP
|
|
#define CK_AMD_XDLOPS_HPP
|
|
|
|
#include "float_type.hpp"
|
|
|
|
namespace ck {
|
|
|
|
// A, B, C, cbsz, abid, blgp
|
|
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
|
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
|
|
|
|
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
|
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
|
|
|
|
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
|
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
|
|
|
|
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
|
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
|
|
|
|
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
|
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
|
|
|
|
// Declarations binding the half4_t-input MFMA LLVM intrinsics to callable
// device functions. Operand order for every entry: A, B, C, cbsz, abid, blgp.
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
    half4_t a, half4_t b, float32_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.32x32x4f16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
    half4_t a, half4_t b, float16_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.32x32x8f16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
    half4_t a, half4_t b, float4_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.16x16x16f16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
    half4_t a, half4_t b, float16_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.16x16x4f16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
    half4_t a, half4_t b, float4_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
|
|
|
|
// Declarations binding the ushort2_t-input (bf16) MFMA LLVM intrinsics to
// callable device functions. Operand order for every entry:
// A, B, C, cbsz, abid, blgp.
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
    ushort2_t a, ushort2_t b, float32_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
    ushort2_t a, ushort2_t b, float16_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
    ushort2_t a, ushort2_t b, float4_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");

extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
    ushort2_t a, ushort2_t b, float16_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");

extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
    ushort2_t a, ushort2_t b, float4_t c, int cbsz, int abid, int blgp)
    __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_32x32x1f32;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
|
1,
|
|
0,
|
|
0);
|
|
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
|
1,
|
|
1,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
|
1,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_32x32x2f32;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
|
0,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_16x16x4f32;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
|
0,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_16x16x1f32;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
|
|
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
|
2,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_4x4x1f32;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
|
4,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
|
4,
|
|
0,
|
|
0);
|
|
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
|
4,
|
|
1,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_32x32x4f16;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
|
1,
|
|
0,
|
|
0);
|
|
reg_c(Number<COffset + 1>{}).template AsType<float32_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset + 1>{}].template AsType<float32_t>()[Number<0>{}],
|
|
1,
|
|
1,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float32_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float32_t>()[Number<0>{}],
|
|
1,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_32x32x8f16;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
|
0,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_16x16x16f16;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
|
0,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_16x16x4f16;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float16_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float16_t>()[Number<0>{}],
|
|
2,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t MPerWave, index_t NPerWave, index_t COffset>
|
|
struct intrin_mfma_f32_4x4x4f16;
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
|
4,
|
|
0,
|
|
0);
|
|
}
|
|
};
|
|
|
|
template <index_t COffset>
|
|
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
|
|
{
|
|
template <class FloatC>
|
|
__device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
|
|
{
|
|
reg_c(Number<COffset>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset>{}].template AsType<float4_t>()[Number<0>{}],
|
|
4,
|
|
0,
|
|
0);
|
|
reg_c(Number<COffset + 1>{}).template AsType<float4_t>()(Number<0>{}) =
|
|
llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
|
reg_a,
|
|
reg_b,
|
|
reg_c[Number<COffset + 1>{}].template AsType<float4_t>()[Number<0>{}],
|
|
4,
|
|
1,
|
|
0);
|
|
}
|
|
};
|
|
|
|
#if 0
|
|
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
|
|
struct intrin_mfma_f32_32x32x2bf16;
|
|
|
|
template <index_t AStride, index_t BStride>
|
|
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
|
|
{
|
|
__device__ static c_vec32_4_t::VecType
|
|
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
|
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
|
|
|
reg_c.s.z =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
|
|
reg_c.s.w =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
|
|
|
|
return reg_c;
|
|
}
|
|
};
|
|
|
|
template <index_t AStride, index_t BStride>
|
|
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
|
|
{
|
|
__device__ static c_vec32_4_t::VecType
|
|
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
|
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
|
|
|
reg_c.s.z =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
|
|
reg_c.s.w =
|
|
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
|
|
|
|
return reg_c;
|
|
}
|
|
};
|
|
|
|
template <index_t AStride, index_t BStride>
|
|
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
|
|
{
|
|
__device__ static c_vec32_2_t::VecType
|
|
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
|
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
|
|
|
return reg_c;
|
|
}
|
|
};
|
|
|
|
template <index_t AStride, index_t BStride>
|
|
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
|
|
{
|
|
__device__ static c_vec32_1_t::VecType
|
|
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
|
|
|
|
return reg_c;
|
|
}
|
|
};
|
|
|
|
template <index_t AStride, index_t BStride>
|
|
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
|
|
{
|
|
__device__ static c_vec32_1_t::VecType
|
|
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
|
return reg_c;
|
|
}
|
|
};
|
|
|
|
// One call to the 32x32x4 bf16 MFMA intrinsic on reg_a[0] / reg_b[0] with
// cbsz = 0, abid = 0, blgp = 0. reg_c is taken by value and returned updated.
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
                                                            const ushort2_t* reg_b,
                                                            c_vec16_1_t::VecType reg_c)
{
    const ushort2_t a_first = reg_a[0];
    const ushort2_t b_first = reg_b[0];

    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(a_first, b_first, reg_c.s.x, 0, 0, 0);

    return reg_c;
}
|
|
|
|
// One call to the 16x16x8 bf16 MFMA intrinsic on reg_a[0] / reg_b[0] with
// cbsz = 0, abid = 0, blgp = 0. reg_c is taken by value and returned updated.
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
                                                           const ushort2_t* reg_b,
                                                           c_vec4_1_t::VecType reg_c)
{
    const ushort2_t a_first = reg_a[0];
    const ushort2_t b_first = reg_b[0];

    reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(a_first, b_first, reg_c.s.x, 0, 0, 0);

    return reg_c;
}
|
|
|
|
template <index_t MPerWave, index_t NPerWave>
|
|
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
|
|
const ushort2_t* reg_b,
|
|
c_vec16_1_t::VecType reg_c);
|
|
|
|
template <>
|
|
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
|
|
const ushort2_t* reg_b,
|
|
c_vec16_1_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
|
|
return reg_c;
|
|
}
|
|
|
|
template <>
|
|
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
|
|
const ushort2_t* reg_b,
|
|
c_vec16_1_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
|
|
return reg_c;
|
|
}
|
|
|
|
// Primary template: declared without a definition; the
// (MPerWave, NPerWave) specializations provide run().
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;

// MPerWave = 4, NPerWave = 64: one call to the 4x4x2 bf16 MFMA intrinsic
// with cbsz = 4, abid = 0, blgp = 0.
template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
    // Takes reg_c by value, updates its single component and returns it.
    __device__ static c_vec4_1_t::VecType
    run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
    {
        const ushort2_t a_first = reg_a[0];
        const ushort2_t b_first = reg_b[0];

        reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(a_first, b_first, reg_c.s.x, 4, 0, 0);

        return reg_c;
    }
};
|
|
|
|
template <>
|
|
struct intrin_mfma_f32_4x4x2bf16<8, 64>
|
|
{
|
|
__device__ static c_vec4_2_t::VecType
|
|
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
|
|
{
|
|
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
|
|
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
|
|
return reg_c;
|
|
}
|
|
};
|
|
|
|
#endif
|
|
|
|
} // namespace ck
|
|
#endif
|