mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 17:26:00 +00:00
Add bfp16/int8 support into XDL GEMM operator (#50)
* init StaticBufferV2 * clean * adopt old output stage for staticBufferV2 * clean * remove hack * clean * clean * add parameters * clean code * move c_buffer alloc into blockwise gemm * add adaptors for m/n_thread_data_on_grid * tweak gemm * adjust blockwise_gemm_xdlops * tweak * update conv * update script * adding bwd 1x1 * update script * adding 1x1 bwd * debugging bwd 1x1 failure * update script * update script * test * test v100 * add bf16_1k * clang-format * clean * add bfp16 for gfx908 * add verification * clean up * clean code * restore bfl16 * clean * add bfp16 support into gemm_driver * apply new generator to other drivers * add int8 support * cleanb * clean * clean * clean Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Chao Liu <lc.roy86@gmail.com> Co-authored-by: root <root@hayabusa6111.amd.com>
This commit is contained in:
@@ -12,18 +12,19 @@ enum struct MfmaInstr
|
||||
mfma_f32_32x32x1xf32 = 0,
|
||||
mfma_f32_16x16x1xf32,
|
||||
mfma_f32_4x4x1xf32,
|
||||
mfma_f32_32x32x2xf32, // k reduction
|
||||
mfma_f32_16x16x4xf32, // k reduction
|
||||
mfma_f32_32x32x2xf32,
|
||||
mfma_f32_16x16x4xf32,
|
||||
mfma_f32_32x32x4f16,
|
||||
mfma_f32_16x16x4f16,
|
||||
mfma_f32_4x4x4f16,
|
||||
mfma_f32_32x32x8f16, // k reduction
|
||||
mfma_f32_16x16x16f16, // k reduction
|
||||
mfma_f32_32x32x2bf16,
|
||||
mfma_f32_16x16x2bf16,
|
||||
mfma_f32_4x4x2bf16,
|
||||
mfma_f32_32x32x4bf16, // k reduction
|
||||
mfma_f32_16x16x8bf16, // k reduction
|
||||
mfma_f32_32x32x8f16,
|
||||
mfma_f32_16x16x16f16,
|
||||
mfma_f32_32x32x8bf16_1k,
|
||||
mfma_f32_16x16x16bf16_1k,
|
||||
mfma_f32_32x32x4bf16,
|
||||
mfma_f32_16x16x8bf16,
|
||||
mfma_i32_32x32x8i8,
|
||||
mfma_i32_16x16x16i8,
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -250,9 +251,8 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
@@ -260,26 +260,38 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
intrin_mfma_f32_32x32x8bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
|
||||
p_a, p_b, reg_c);
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x16bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -298,19 +310,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
|
||||
intrin_mfma_f32_32x32x4bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -329,24 +332,37 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
|
||||
intrin_mfma_f32_16x16x8bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
|
||||
struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -354,60 +370,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
|
||||
intrin_mfma_i32_16x16x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 64;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 4;
|
||||
static constexpr index_t n_per_blk = 64;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
struct MfmaSelector
|
||||
{
|
||||
@@ -498,73 +473,37 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_f32_4x4x4f16;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 128, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 128>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 32>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 32, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 16>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 16, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 8, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 4, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 32, 32>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
|
||||
#if defined(CK_AMD_GPU_GFX90A)
|
||||
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x4bf16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 16, 16>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
|
||||
}
|
||||
#if defined(CK_AMD_GPU_GFX90A)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x8bf16;
|
||||
#endif
|
||||
}
|
||||
|
||||
// Selector specialization for int8_t with MPerXdlops = 32, NPerXdlops = 32:
// yields the MfmaInstr::mfma_i32_32x32x8i8 enumerator.
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x8i8;
|
||||
}
|
||||
|
||||
// Selector specialization for int8_t with MPerXdlops = 16, NPerXdlops = 16:
// yields the MfmaInstr::mfma_i32_16x16x16i8 enumerator.
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
}
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
@@ -686,8 +625,8 @@ struct XdlopsGemm
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, ushort>::value,
|
||||
"base base_type must be float, half, ushort!");
|
||||
is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
|
||||
"base base_type must be float, half, ushort, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
|
||||
@@ -50,11 +50,24 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
|
||||
|
||||
__device__ int16_t
|
||||
__device__ ushort
|
||||
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
|
||||
|
||||
__device__ ushort2_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
|
||||
|
||||
__device__ ushort4_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
|
||||
|
||||
__device__ int32_t
|
||||
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
@@ -133,12 +146,26 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
|
||||
llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
@@ -228,6 +255,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
@@ -326,6 +354,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
return as_type<half8_t>(tmp);
|
||||
}
|
||||
}
|
||||
// ushort branch of the buffer load: dispatches on the vector width N.
// N = 1, 2 and 4 map directly onto the i16, i16x2 and i16x4 raw buffer load
// intrinsics; every call passes 0 as the last (glc_slc) argument.
else if constexpr(is_same<T, ushort>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
// N = 8 is fetched with a single int32x4_t load and the result is then
// reinterpreted as ushort8_t via as_type.
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||

||||
return as_type<ushort8_t>(tmp);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
@@ -458,6 +511,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
(is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
@@ -552,6 +606,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, ushort>::value)
{
    // ushort branch of the buffer store, dispatched on the vector width N.
    // Every width goes through the 16-bit integer store intrinsics, whose
    // vdata parameters are declared as ushort / ushort2_t / ushort4_t, so the
    // data is passed in its own type instead of being routed through the
    // fp16 (half) store intrinsics.
    if constexpr(N == 1)
    {
        llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
                                         dst_wave_buffer_resource,
                                         dst_thread_addr_offset,
                                         dst_wave_addr_offset,
                                         0);
    }
    else if constexpr(N == 2)
    {
        llvm_amdgcn_raw_buffer_store_i16x2(src_thread_data,
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           0);
    }
    else if constexpr(N == 4)
    {
        llvm_amdgcn_raw_buffer_store_i16x4(src_thread_data,
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           0);
    }
    else if constexpr(N == 8)
    {
        // Eight elements are written as two 4-element stores; the second
        // store is offset by the byte size of the first four elements.
        vector_type<ushort, 8> tmp{src_thread_data};

        llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<0>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
                                           0);

        llvm_amdgcn_raw_buffer_store_i16x4(tmp.AsType<ushort4_t>()[Number<1>{}],
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset + 4 * sizeof(ushort),
                                           0);
    }
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
namespace ck {
|
||||
|
||||
// A, B, C, cbsz, abid, blgp
|
||||
// fp32
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
|
||||
|
||||
@@ -21,6 +22,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
|
||||
|
||||
// fp16
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
|
||||
|
||||
@@ -36,6 +38,13 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
|
||||
|
||||
// bfp16
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
|
||||
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
|
||||
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
|
||||
|
||||
@@ -51,6 +60,23 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
|
||||
|
||||
// int8
|
||||
extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
|
||||
int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
|
||||
|
||||
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
|
||||
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
|
||||
|
||||
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
|
||||
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
|
||||
|
||||
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
|
||||
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
|
||||
|
||||
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
|
||||
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
|
||||
|
||||
// fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
|
||||
@@ -148,6 +174,7 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
|
||||
}
|
||||
};
|
||||
|
||||
// fp16
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x4f16;
|
||||
|
||||
@@ -244,147 +271,102 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16;
|
||||
// bfp16
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x8bf16_1k;
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
|
||||
{
|
||||
__device__ static c_vec32_4_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
reg_c.s.z =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
|
||||
reg_c.s.w =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_4_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
reg_c.s.z =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
|
||||
reg_c.s.w =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_2_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec4_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c);
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
struct intrin_mfma_f32_16x16x16bf16_1k;
|
||||
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
|
||||
return reg_c;
|
||||
}
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x2bf16;
|
||||
struct intrin_mfma_f32_32x32x4bf16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x2bf16<4, 64>
|
||||
struct intrin_mfma_f32_32x32x4bf16<32, 32>
|
||||
{
|
||||
__device__ static c_vec4_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
|
||||
return reg_c;
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x8bf16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x2bf16<8, 64>
|
||||
struct intrin_mfma_f32_16x16x8bf16<16, 16>
|
||||
{
|
||||
__device__ static c_vec4_2_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
|
||||
return reg_c;
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
// Wrapper around llvm_intrin_amdgcn_mfma_i32_32x32x8i8. The primary template
// is only declared; the <32, 32> specialization below provides Run().
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_32x32x8i8;

template <>
struct intrin_mfma_i32_32x32x8i8<32, 32>
{
    // Reinterprets each int8x4_t operand as one int, then accumulates into
    // element 0 of reg_c viewed as int32x16_t. The three trailing intrinsic
    // arguments (cbsz, abid, blgp) are all passed as 0.
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto slot = Number<0>{};

        const int packed_a = as_type<int>(reg_a);
        const int packed_b = as_type<int>(reg_b);

        reg_c.template AsType<int32x16_t>()(slot) = llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
            packed_a, packed_b, reg_c.template AsType<int32x16_t>()[slot], 0, 0, 0);
    }
};
|
||||
|
||||
// Wrapper around llvm_intrin_amdgcn_mfma_i32_16x16x16i8. The primary template
// is only declared; the <16, 16> specialization below provides Run().
template <index_t MPerWave, index_t NPerWave>
struct intrin_mfma_i32_16x16x16i8;

template <>
struct intrin_mfma_i32_16x16x16i8<16, 16>
{
    // Reinterprets each int8x4_t operand as one int, then accumulates into
    // element 0 of reg_c viewed as int32x4_t. The three trailing intrinsic
    // arguments (cbsz, abid, blgp) are all passed as 0.
    template <class FloatC>
    __device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
    {
        constexpr auto slot = Number<0>{};

        const int packed_a = as_type<int>(reg_a);
        const int packed_b = as_type<int>(reg_b);

        reg_c.template AsType<int32x4_t>()(slot) = llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
            packed_a, packed_b, reg_c.template AsType<int32x4_t>()[slot], 0, 0, 0);
    }
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user