Files
composable_kernel/composable_kernel/include/utility/amd_xdlops.hpp
zjing14 2331d228e2 xdlops_v4r4_fwd fp32/fp16 (#34)
* create files for xdlops

* working on blockwise_gemm_xdlops

* add KReduction

* add m/n repeats

* add 2x2 pipeline

* added 128x128 wavegemm

* use StaticBuffer of vector_type

* break vector type to blk_size

* add kpack into xdlops_gemm and blockwise_gemm

* abroadcast only

* add fp32 mfma instructions

* adding fp16 mfma

* pack half4_t

* rename kperwave to kpack

* add 32x32x8fp16

* add fp16 mfma

* clean code

* clean code

* V4r4 xdlops kpack (#35)

* add kpack with incorrect results

* bug fix for make_dynamic_naive_tensor_descriptor_aligned_v2

* add 1x1 kernel

* add gridwise_gemm_v2 - single_buffer

* enabled dwordx4 for fp16

Co-authored-by: Chao Liu <chao.liu2@amd.com>

* refactor fwd-v4r4-xdlops

* add v4r4-nhwc-xdlop

* improve some perf of nhwc and nchw by tuning parameters, and change scheduling in gridwise-gemm loop

* tweak scheduling in gridwise gemm

* add v4r3 with a single output copy

* init commit: output with slice win

* adding sliceWin

* add multiple repeats pattern

* starting adding bwd-v4r1-xdlops

* use tuple as SrcBuffer

* adding bwd-data v4r1 nhwc xdlops

* fix bug in make_dynamic_naive_tensor_descriptor_aligned_v2()

* fix bug in host bwd-data conv

* initial implementation of bwd-data v4r1 nhwc xdlops

* add launch bound flags

* enable launch bound

* add m/nrepeat=4

* tweak bwd-data v4r1 nhwc xdlops

* added bwd-data v4r1 nhwc xdlops with output A and weight B

* add fwd-v4r4 nhwc xdlops, A input, B weight, C output

Co-authored-by: Chao Liu <chao.liu2@amd.com>

[ROCm/composable_kernel commit: 3835318cc3]
2021-07-01 14:33:00 -05:00

500 lines
17 KiB
C++

#ifndef CK_AMD_XDLOPS_HPP
#define CK_AMD_XDLOPS_HPP
#include "float_type.hpp"
namespace ck {
// Device-side declarations that bind C++ names to LLVM AMDGPU MFMA intrinsics
// through __asm labels ("llvm.amdgcn.mfma.f32.<shape><type>").
// Argument order for every declaration: A, B, C, cbsz, abid, blgp
// In each declaration the return type is the same vector type as the C argument.

// A and B passed as scalar float.
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x2f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x4f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f32");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
float, float, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x1f32");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
// A and B passed as half4_t.
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16f16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
half4_t, half4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x4f16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
// A and B passed as ushort2_t (the bf16 intrinsic variants).
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x8bf16");
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
ushort2_t, ushort2_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x2bf16");
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
// Wrapper around llvm_intrin_amdgcn_mfma_f32_32x32x1f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only; Run() is provided
// by explicit specializations.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x1f32;

// <64, 64>: issues the intrinsic twice (cbsz = 1; abid = 0, then abid = 1) and
// stores the results in reg_c slots COffset and COffset + 1 respectively.
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<64, 64, COffset>
{
    // reg_a, reg_b : A and B operands, forwarded unchanged to both calls.
    // reg_c        : container indexed with Number<>; each used slot is viewed
    //                through AsType<float32_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0   = Number<0>{};
        constexpr auto slot_0 = Number<COffset>{};
        constexpr auto slot_1 = Number<COffset + 1>{};

        // Call 1 (abid = 0): slot COffset is the C operand and receives the result.
        const auto acc_0 = reg_c[slot_0].template AsType<float32_t>()[vec0];
        reg_c(slot_0).template AsType<float32_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc_0, 1, 0, 0);

        // Call 2 (abid = 1): slot COffset + 1 is the C operand and receives the result.
        const auto acc_1 = reg_c[slot_1].template AsType<float32_t>()[vec0];
        reg_c(slot_1).template AsType<float32_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc_1, 1, 1, 0);
    }
};
// <32, 64>: a single llvm_intrin_amdgcn_mfma_f32_32x32x1f32 call
// (cbsz = 1, abid = 0, blgp = 0) whose result replaces reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_32x32x1f32<32, 64, COffset>
{
    // reg_a, reg_b : A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float32_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float32_t>()[vec0];

        reg_c(slot).template AsType<float32_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x1f32(reg_a, reg_b, acc_in, 1, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_32x32x2f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only; Run() is provided
// by the explicit specialization.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x2f32;

// <32, 32>: one intrinsic call with cbsz = abid = blgp = 0; the result replaces
// reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_32x32x2f32<32, 32, COffset>
{
    // reg_a, reg_b : A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float16_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float16_t>()[vec0];

        reg_c(slot).template AsType<float16_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x2f32(reg_a, reg_b, acc_in, 0, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_16x16x4f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only; Run() is provided
// by the explicit specialization.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f32;

// <16, 16>: one intrinsic call with cbsz = abid = blgp = 0; the result replaces
// reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f32<16, 16, COffset>
{
    // reg_a, reg_b : A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float4_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float4_t>()[vec0];

        reg_c(slot).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_16x16x4f32(reg_a, reg_b, acc_in, 0, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_16x16x1f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only; Run() is provided
// by the explicit specialization.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x1f32;

// <16, 64>: one intrinsic call with cbsz = 2, abid = 0, blgp = 0; the result
// replaces reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_16x16x1f32<16, 64, COffset>
{
    // reg_a, reg_b : A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float16_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float16_t>()[vec0];

        reg_c(slot).template AsType<float16_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_16x16x1f32(reg_a, reg_b, acc_in, 2, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_4x4x1f32, selected by
// <MPerWave, NPerWave>. The primary template is declared only; Run() is provided
// by explicit specializations.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x1f32;

// <4, 64>: one intrinsic call with cbsz = 4, abid = 0, blgp = 0; the result
// replaces reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<4, 64, COffset>
{
    // reg_a, reg_b : A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float4_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float4_t>()[vec0];

        reg_c(slot).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a, reg_b, acc_in, 4, 0, 0);
    }
};
// <8, 64>: issues llvm_intrin_amdgcn_mfma_f32_4x4x1f32 twice (cbsz = 4; abid = 0,
// then abid = 1) and stores the results in reg_c slots COffset and COffset + 1.
template <index_t COffset>
struct intrin_mfma_f32_4x4x1f32<8, 64, COffset>
{
    // reg_a, reg_b : A and B operands, forwarded unchanged to both calls.
    // reg_c        : container indexed with Number<>; each used slot is viewed
    //                through AsType<float4_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0   = Number<0>{};
        constexpr auto slot_0 = Number<COffset>{};
        constexpr auto slot_1 = Number<COffset + 1>{};

        // Call 1 (abid = 0): slot COffset is the C operand and receives the result.
        const auto acc_0 = reg_c[slot_0].template AsType<float4_t>()[vec0];
        reg_c(slot_0).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a, reg_b, acc_0, 4, 0, 0);

        // Call 2 (abid = 1): slot COffset + 1 is the C operand and receives the result.
        const auto acc_1 = reg_c[slot_1].template AsType<float4_t>()[vec0];
        reg_c(slot_1).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_4x4x1f32(reg_a, reg_b, acc_1, 4, 1, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_32x32x4f16 (half4_t A/B operands),
// selected by <MPerWave, NPerWave>. The primary template is declared only; Run()
// is provided by explicit specializations.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x4f16;

// <64, 64>: issues the intrinsic twice (cbsz = 1; abid = 0, then abid = 1) and
// stores the results in reg_c slots COffset and COffset + 1 respectively.
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<64, 64, COffset>
{
    // reg_a, reg_b : half4_t A and B operands, forwarded unchanged to both calls.
    // reg_c        : container indexed with Number<>; each used slot is viewed
    //                through AsType<float32_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0   = Number<0>{};
        constexpr auto slot_0 = Number<COffset>{};
        constexpr auto slot_1 = Number<COffset + 1>{};

        // Call 1 (abid = 0): slot COffset is the C operand and receives the result.
        const auto acc_0 = reg_c[slot_0].template AsType<float32_t>()[vec0];
        reg_c(slot_0).template AsType<float32_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc_0, 1, 0, 0);

        // Call 2 (abid = 1): slot COffset + 1 is the C operand and receives the result.
        const auto acc_1 = reg_c[slot_1].template AsType<float32_t>()[vec0];
        reg_c(slot_1).template AsType<float32_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc_1, 1, 1, 0);
    }
};
// <32, 64>: a single llvm_intrin_amdgcn_mfma_f32_32x32x4f16 call
// (cbsz = 1, abid = 0, blgp = 0) whose result replaces reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_32x32x4f16<32, 64, COffset>
{
    // reg_a, reg_b : half4_t A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float32_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float32_t>()[vec0];

        reg_c(slot).template AsType<float32_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x4f16(reg_a, reg_b, acc_in, 1, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_32x32x8f16 (half4_t A/B operands),
// selected by <MPerWave, NPerWave>. The primary template is declared only; Run()
// is provided by the explicit specialization.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_32x32x8f16;

// <32, 32>: one intrinsic call with cbsz = abid = blgp = 0; the result replaces
// reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_32x32x8f16<32, 32, COffset>
{
    // reg_a, reg_b : half4_t A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float16_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float16_t>()[vec0];

        reg_c(slot).template AsType<float16_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_32x32x8f16(reg_a, reg_b, acc_in, 0, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_16x16x16f16 (half4_t A/B operands),
// selected by <MPerWave, NPerWave>. The primary template is declared only; Run()
// is provided by the explicit specialization.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x16f16;

// <16, 16>: one intrinsic call with cbsz = abid = blgp = 0; the result replaces
// reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_16x16x16f16<16, 16, COffset>
{
    // reg_a, reg_b : half4_t A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float4_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float4_t>()[vec0];

        reg_c(slot).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_16x16x16f16(reg_a, reg_b, acc_in, 0, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_16x16x4f16 (half4_t A/B operands),
// selected by <MPerWave, NPerWave>. The primary template is declared only; Run()
// is provided by the explicit specialization.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_16x16x4f16;

// <16, 64>: one intrinsic call with cbsz = 2, abid = 0, blgp = 0; the result
// replaces reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_16x16x4f16<16, 64, COffset>
{
    // reg_a, reg_b : half4_t A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float16_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float16_t>()[vec0];

        reg_c(slot).template AsType<float16_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_16x16x4f16(reg_a, reg_b, acc_in, 2, 0, 0);
    }
};
// Wrapper around llvm_intrin_amdgcn_mfma_f32_4x4x4f16 (half4_t A/B operands),
// selected by <MPerWave, NPerWave>. The primary template is declared only; Run()
// is provided by explicit specializations.
template <index_t MPerWave, index_t NPerWave, index_t COffset>
struct intrin_mfma_f32_4x4x4f16;

// <4, 64>: one intrinsic call with cbsz = 4, abid = 0, blgp = 0; the result
// replaces reg_c slot COffset.
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<4, 64, COffset>
{
    // reg_a, reg_b : half4_t A and B operands passed straight to the intrinsic.
    // reg_c        : container indexed with Number<>; slot COffset is viewed
    //                through AsType<float4_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0 = Number<0>{};
        constexpr auto slot = Number<COffset>{};

        // Current slot contents serve as the C operand.
        const auto acc_in = reg_c[slot].template AsType<float4_t>()[vec0];

        reg_c(slot).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a, reg_b, acc_in, 4, 0, 0);
    }
};
// <8, 64>: issues llvm_intrin_amdgcn_mfma_f32_4x4x4f16 twice (cbsz = 4; abid = 0,
// then abid = 1) and stores the results in reg_c slots COffset and COffset + 1.
template <index_t COffset>
struct intrin_mfma_f32_4x4x4f16<8, 64, COffset>
{
    // reg_a, reg_b : half4_t A and B operands, forwarded unchanged to both calls.
    // reg_c        : container indexed with Number<>; each used slot is viewed
    //                through AsType<float4_t>() and element 0 is updated in place.
    template <class FloatC>
    __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto vec0   = Number<0>{};
        constexpr auto slot_0 = Number<COffset>{};
        constexpr auto slot_1 = Number<COffset + 1>{};

        // Call 1 (abid = 0): slot COffset is the C operand and receives the result.
        const auto acc_0 = reg_c[slot_0].template AsType<float4_t>()[vec0];
        reg_c(slot_0).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a, reg_b, acc_0, 4, 0, 0);

        // Call 2 (abid = 1): slot COffset + 1 is the C operand and receives the result.
        const auto acc_1 = reg_c[slot_1].template AsType<float4_t>()[vec0];
        reg_c(slot_1).template AsType<float4_t>()(vec0) =
            llvm_intrin_amdgcn_mfma_f32_4x4x4f16(reg_a, reg_b, acc_1, 4, 1, 0);
    }
};
#if 0
// NOTE: this code sits inside an `#if 0` region and is currently not compiled.
// Wrapper family around llvm_intrin_amdgcn_mfma_f32_32x32x2bf16, selected by
// <MPerWave, NPerWave>. AStride / BStride are used as indices into the reg_a /
// reg_b arrays. Each run() takes reg_c by value, updates its .s.x/.y/.z/.w
// members with intrinsic results, and returns the updated copy.
// The primary template is declared only.
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16;
// <128, 64>: four calls, all with reg_b[0] and cbsz = 1.
//   .x / .y use reg_a[0]       with abid = 0 / 1
//   .z / .w use reg_a[AStride] with abid = 0 / 1
// BStride is not used by this specialization.
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
// <64, 128>: four calls, all with reg_a[0] and cbsz = 1.
//   .x / .y use reg_b[0]       with abid = 0 / 1
//   .z / .w use reg_b[BStride] with abid = 0 / 1
// AStride is not used by this specialization.
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
{
__device__ static c_vec32_4_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
reg_c.s.z =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
reg_c.s.w =
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
return reg_c;
}
};
// <64, 64>: two calls with reg_a[0], reg_b[0], cbsz = 1; .x uses abid = 0 and
// .y uses abid = 1. Neither stride parameter is used.
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
{
__device__ static c_vec32_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
return reg_c;
}
};
// <64, 32>: one call with cbsz = 0, abid = 0, blgp = 1 (the only specialization
// in this family that passes a non-zero blgp).
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
return reg_c;
}
};
// <32, 64>: one call with cbsz = 1, abid = 0, blgp = 0.
template <index_t AStride, index_t BStride>
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
{
__device__ static c_vec32_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
return reg_c;
}
};
// NOTE: this code sits inside an `#if 0` region and is currently not compiled.
// Calls llvm_intrin_amdgcn_mfma_f32_32x32x4bf16 once with reg_a[0], reg_b[0] and
// cbsz = abid = blgp = 0, using reg_c.s.x as the C operand. reg_c is taken by
// value; the updated copy is returned.
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
// NOTE: this code sits inside an `#if 0` region and is currently not compiled.
// Calls llvm_intrin_amdgcn_mfma_f32_16x16x8bf16 once with reg_a[0], reg_b[0] and
// cbsz = abid = blgp = 0, using reg_c.s.x as the C operand. reg_c is taken by
// value; the updated copy is returned.
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
return reg_c;
}
// NOTE: this code sits inside an `#if 0` region and is currently not compiled.
// Function-template wrapper around llvm_intrin_amdgcn_mfma_f32_16x16x2bf16,
// selected by <MPerWave, NPerWave>. The primary template is declared only; the
// two full specializations below each make one call with reg_a[0], reg_b[0] and
// reg_c.s.x as the C operand, then return the updated by-value copy of reg_c.
template <index_t MPerWave, index_t NPerWave>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c);
// <16, 64>: cbsz = 2, abid = 0, blgp = 0.
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
return reg_c;
}
// <64, 16>: cbsz = 0, abid = 0, blgp = 4.
template <>
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
const ushort2_t* reg_b,
c_vec16_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
return reg_c;
}
// NOTE: this code sits inside an `#if 0` region and is currently not compiled.
// Wrapper family around llvm_intrin_amdgcn_mfma_f32_4x4x2bf16, selected by
// <MPerWave, NPerWave>. The primary template is declared only. Each run() takes
// reg_c by value, updates it with intrinsic results, and returns the copy.
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_f32_4x4x2bf16;
// <4, 64>: one call with reg_a[0], reg_b[0], cbsz = 4, abid = 0, blgp = 0,
// updating reg_c.s.x.
template <>
struct intrin_mfma_f32_4x4x2bf16<4, 64>
{
__device__ static c_vec4_1_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
return reg_c;
}
};
// <8, 64>: two calls with reg_a[0], reg_b[0], cbsz = 4; reg_c.s.x uses abid = 0
// and reg_c.s.y uses abid = 1.
template <>
struct intrin_mfma_f32_4x4x2bf16<8, 64>
{
__device__ static c_vec4_2_t::VecType
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
{
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
return reg_c;
}
};
#endif
} // namespace ck
#endif