mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add bfp16/int8 support into XDL GEMM operator (#50)
* init StaticBufferV2 * clean * adopt old output stage for staticBufferV2 * clean * remove hack * clean * clean * add parameters * clean code * move c_buffer alloc into blockwise gemm * add adaptors for m/n_thread_data_on_grid * tweak gemm * adjust blockwise_gemm_xdlops * tweak * update conv * update script * adding bwd 1x1 * update script * adding 1x1 bwd * debugging bwd 1x1 failure * update script * update script * test * test v100 * add bf16_1k * clang-format * clean * add bfp16 for gfx908 * add verification * clean up * clean code * restore bfl16 * clean * add bfp16 support into gemm_driver * apply new generator to other drivers * add int8 support * cleanb * clean * clean * clean Co-authored-by: Chao Liu <chao.liu2@amd.com> Co-authored-by: Chao Liu <lc.roy86@gmail.com> Co-authored-by: root <root@hayabusa6111.amd.com>
This commit is contained in:
@@ -12,18 +12,19 @@ enum struct MfmaInstr
|
||||
mfma_f32_32x32x1xf32 = 0,
|
||||
mfma_f32_16x16x1xf32,
|
||||
mfma_f32_4x4x1xf32,
|
||||
mfma_f32_32x32x2xf32, // k reduction
|
||||
mfma_f32_16x16x4xf32, // k reduction
|
||||
mfma_f32_32x32x2xf32,
|
||||
mfma_f32_16x16x4xf32,
|
||||
mfma_f32_32x32x4f16,
|
||||
mfma_f32_16x16x4f16,
|
||||
mfma_f32_4x4x4f16,
|
||||
mfma_f32_32x32x8f16, // k reduction
|
||||
mfma_f32_16x16x16f16, // k reduction
|
||||
mfma_f32_32x32x2bf16,
|
||||
mfma_f32_16x16x2bf16,
|
||||
mfma_f32_4x4x2bf16,
|
||||
mfma_f32_32x32x4bf16, // k reduction
|
||||
mfma_f32_16x16x8bf16, // k reduction
|
||||
mfma_f32_32x32x8f16,
|
||||
mfma_f32_16x16x16f16,
|
||||
mfma_f32_32x32x8bf16_1k,
|
||||
mfma_f32_16x16x16bf16_1k,
|
||||
mfma_f32_32x32x4bf16,
|
||||
mfma_f32_16x16x8bf16,
|
||||
mfma_i32_32x32x8i8,
|
||||
mfma_i32_16x16x16i8,
|
||||
};
|
||||
|
||||
template <MfmaInstr instr>
|
||||
@@ -250,9 +251,8 @@ struct mfma_type<MfmaInstr::mfma_f32_4x4x4f16>
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_32x32x8bf16_1k>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
@@ -260,26 +260,38 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x2bf16>
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
intrin_mfma_f32_32x32x8bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
return intrin_mfma_f32_32x32x2bf16<MPerXdlops, NPerXdlops, AStride, BStride>::run(
|
||||
p_a, p_b, reg_c);
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x16bf16_1k>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_f32_16x16x16bf16_1k<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -298,19 +310,10 @@ struct mfma_type<MfmaInstr::mfma_f32_32x32x4bf16>
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_32x32x4bf16(p_a, p_b, reg_c);
|
||||
intrin_mfma_f32_32x32x4bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -329,24 +332,37 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x8bf16>
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_16x16x8bf16(p_a, p_b, reg_c);
|
||||
intrin_mfma_f32_16x16x8bf16<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
|
||||
struct mfma_type<MfmaInstr::mfma_i32_32x32x8i8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 4;
|
||||
static constexpr index_t num_regs_per_blk = 16;
|
||||
static constexpr index_t num_threads_per_blk = 32;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 2;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 32;
|
||||
static constexpr index_t n_per_blk = 32;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
intrin_mfma_i32_32x32x8i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_i32_16x16x16i8>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
@@ -354,60 +370,19 @@ struct mfma_type<MfmaInstr::mfma_f32_16x16x2bf16>
|
||||
static constexpr index_t num_threads_per_blk = 16;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 4;
|
||||
static constexpr index_t num_output_blks = 4;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 16;
|
||||
static constexpr index_t n_per_blk = 16;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
static constexpr index_t k_per_blk = 4;
|
||||
static constexpr bool is_k_reduction = true;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
template <index_t MPerXdlops, index_t NPerXdlops, class FloatA, class FloatB, class FloatC>
|
||||
__device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_16x16x2bf16<MPerXdlops, NPerXdlops>(p_a, p_b, reg_c);
|
||||
intrin_mfma_i32_16x16x16i8<MPerXdlops, NPerXdlops>::Run(a, b, reg_c);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct mfma_type<MfmaInstr::mfma_f32_4x4x2bf16>
|
||||
{
|
||||
static constexpr index_t group_size = 4;
|
||||
static constexpr index_t num_groups_per_blk = 1;
|
||||
static constexpr index_t num_regs_per_blk = 4;
|
||||
static constexpr index_t num_threads_per_blk = 64;
|
||||
static constexpr index_t wave_size = 64;
|
||||
static constexpr index_t num_input_blks = 1;
|
||||
static constexpr index_t num_output_blks = 1;
|
||||
static constexpr index_t m_per_blk = 4;
|
||||
static constexpr index_t n_per_blk = 64;
|
||||
static constexpr index_t k_per_blk = 2;
|
||||
static constexpr bool is_k_reduction = false;
|
||||
|
||||
template <index_t MPerXdlops,
|
||||
index_t NPerXdlops,
|
||||
index_t AStride,
|
||||
index_t BStride,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ FloatC run(const FloatA* a, const FloatB* b, FloatC reg_c) const
|
||||
{
|
||||
const auto p_a = c_style_pointer_cast<const ushort2_t*>(a);
|
||||
const auto p_b = c_style_pointer_cast<const ushort2_t*>(b);
|
||||
|
||||
return intrin_mfma_f32_4x4x2bf16<MPerXdlops, NPerXdlops>::run(p_a, p_b, reg_c);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename base_type, index_t MPerXdlops, index_t NPerXdlops>
|
||||
struct MfmaSelector
|
||||
{
|
||||
@@ -498,73 +473,37 @@ struct MfmaSelector
|
||||
return MfmaInstr::mfma_f32_4x4x4f16;
|
||||
}
|
||||
|
||||
#if 0
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 128, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 2, 1, c_vec32_4_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 128>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 2, c_vec32_4_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 64, 1, 1, c_vec32_2_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 32>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 64, 32, 1, 1, c_vec32_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 32, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x2bf16, 32, 64, 1, 1, c_vec32_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 64, 16>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 64, 16, 1, 1, c_vec16_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 16, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_16x16x2bf16, 16, 64, 1, 1, c_vec16_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 8, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 8, 64, 1, 1, c_vec4_2_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 4, 64>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_4x4x2bf16, 4, 64, 1, 1, c_vec4_1_t>{};
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 32, 32>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_32x32x4bf16, 32, 32, 1, 1, c_vec16_1_t>{};
|
||||
#if defined(CK_AMD_GPU_GFX90A)
|
||||
return MfmaInstr::mfma_f32_32x32x8bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_32x32x4bf16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<ushort, 16, 16>()
|
||||
{
|
||||
return xdlops_info<MfmaInstr::mfma_f32_16x16x8bf16, 16, 16, 1, 1, c_vec4_1_t>{};
|
||||
}
|
||||
#if defined(CK_AMD_GPU_GFX90A)
|
||||
return MfmaInstr::mfma_f32_16x16x16bf16_1k;
|
||||
#else
|
||||
return MfmaInstr::mfma_f32_16x16x8bf16;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 32, 32>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_32x32x8i8;
|
||||
}
|
||||
|
||||
template <>
|
||||
static constexpr auto GetMfma<int8_t, 16, 16>()
|
||||
{
|
||||
return MfmaInstr::mfma_i32_16x16x16i8;
|
||||
}
|
||||
|
||||
static constexpr auto selected_mfma = mfma_type<GetMfma<base_type, MPerXdlops, NPerXdlops>()>{};
|
||||
|
||||
@@ -686,8 +625,8 @@ struct XdlopsGemm
|
||||
__device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const
|
||||
{
|
||||
static_assert(is_same<base_type, float>::value || is_same<base_type, half_t>::value ||
|
||||
is_same<base_type, ushort>::value,
|
||||
"base base_type must be float, half, ushort!");
|
||||
is_same<base_type, ushort>::value || is_same<base_type, int8_t>::value,
|
||||
"base base_type must be float, half, ushort, and int8_t!");
|
||||
|
||||
static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) {
|
||||
mfma_instr.template run<MPerXdlops, NPerXdlops>(p_a_wave[k], p_b_wave[k], p_c_thread);
|
||||
|
||||
@@ -50,11 +50,24 @@ llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
|
||||
|
||||
__device__ int16_t
|
||||
__device__ ushort
|
||||
llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i16");
|
||||
|
||||
__device__ ushort2_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i16");
|
||||
|
||||
__device__ ushort4_t
|
||||
llvm_amdgcn_raw_buffer_load_i16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i16");
|
||||
|
||||
__device__ int32_t
|
||||
llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
@@ -133,12 +146,26 @@ llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
|
||||
llvm_amdgcn_raw_buffer_store_i16(ushort vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i16x2(ushort2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i16x4(ushort4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i16");
|
||||
|
||||
__device__ void
|
||||
llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
|
||||
int32x4_t rsrc,
|
||||
@@ -228,6 +255,7 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
(is_same<T, double>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
@@ -326,6 +354,31 @@ __device__ typename vector_type<T, N>::type amd_buffer_load_impl(int32x4_t src_w
|
||||
return as_type<half8_t>(tmp);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, ushort>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return llvm_amdgcn_raw_buffer_load_i16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
return as_type<ushort8_t>(tmp);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
@@ -458,6 +511,7 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
(is_same<T, double>::value && (N == 1 || N == 2)) ||
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, ushort>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)),
|
||||
"wrong! not implemented");
|
||||
@@ -552,6 +606,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type<T, N>::type src
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, ushort>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_i16(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<half_t, 8> tmp{src_thread_data};
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(half_t),
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
namespace ck {
|
||||
|
||||
// A, B, C, cbsz, abid, blgp
|
||||
// fp32
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x1f32(
|
||||
float, float, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x1f32");
|
||||
|
||||
@@ -21,6 +22,7 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x1f32(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x1f32(
|
||||
float, float, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x1f32");
|
||||
|
||||
// fp16
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x4f16(
|
||||
half4_t, half4_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x4f16");
|
||||
|
||||
@@ -36,6 +38,13 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x4f16(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x4f16(
|
||||
half4_t, half4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x4f16");
|
||||
|
||||
// bfp16
|
||||
extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
ushort4_t, ushort4_t, float16_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x8bf16.1k");
|
||||
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
ushort4_t, ushort4_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.16x16x16bf16.1k");
|
||||
|
||||
extern "C" __device__ float32_t llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(
|
||||
ushort2_t, ushort2_t, float32_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.32x32x2bf16");
|
||||
|
||||
@@ -51,6 +60,23 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(
|
||||
extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(
|
||||
ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16");
|
||||
|
||||
// int8
|
||||
extern "C" __device__ int32x32_t llvm_intrin_amdgcn_mfma_i32_32x32x4i8(
|
||||
int, int, int32x32_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x4i8");
|
||||
|
||||
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_16x16x4i8(
|
||||
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x4i8");
|
||||
|
||||
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_4x4x4i8(
|
||||
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.4x4x4i8");
|
||||
|
||||
extern "C" __device__ int32x16_t llvm_intrin_amdgcn_mfma_i32_32x32x8i8(
|
||||
int, int, int32x16_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.32x32x8i8");
|
||||
|
||||
extern "C" __device__ int32x4_t llvm_intrin_amdgcn_mfma_i32_16x16x16i8(
|
||||
int, int, int32x4_t, int, int, int) __asm("llvm.amdgcn.mfma.i32.16x16x16i8");
|
||||
|
||||
// fp32
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x1f32;
|
||||
|
||||
@@ -148,6 +174,7 @@ struct intrin_mfma_f32_4x4x1f32<8, 64>
|
||||
}
|
||||
};
|
||||
|
||||
// fp16
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x4f16;
|
||||
|
||||
@@ -244,147 +271,102 @@ struct intrin_mfma_f32_4x4x4f16<8, 64>
|
||||
}
|
||||
};
|
||||
|
||||
#if 0
|
||||
template <index_t MPerWave, index_t NPerWave, index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16;
|
||||
// bfp16
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_32x32x8bf16_1k;
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<128, 64, AStride, BStride>
|
||||
template <>
|
||||
struct intrin_mfma_f32_32x32x8bf16_1k<32, 32>
|
||||
{
|
||||
__device__ static c_vec32_4_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
reg_c.s.z =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.z, 1, 0, 0);
|
||||
reg_c.s.w =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[AStride], reg_b[0], reg_c.s.w, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x8bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 128, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_4_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_4_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
reg_c.s.z =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.z, 1, 0, 0);
|
||||
reg_c.s.w =
|
||||
llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[BStride], reg_c.s.w, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 64, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_2_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_2_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 1, 1, 0);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<64, 32, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 1);
|
||||
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t AStride, index_t BStride>
|
||||
struct intrin_mfma_f32_32x32x2bf16<32, 64, AStride, BStride>
|
||||
{
|
||||
__device__ static c_vec32_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec32_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 1, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
};
|
||||
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_32x32x4bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
__device__ c_vec4_1_t::VecType intrin_mfma_f32_16x16x8bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec4_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c);
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 2, 0, 0);
|
||||
return reg_c;
|
||||
}
|
||||
struct intrin_mfma_f32_16x16x16bf16_1k;
|
||||
|
||||
template <>
|
||||
__device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<64, 16>(const ushort2_t* reg_a,
|
||||
const ushort2_t* reg_b,
|
||||
c_vec16_1_t::VecType reg_c)
|
||||
struct intrin_mfma_f32_16x16x16bf16_1k<16, 16>
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_16x16x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 0, 0, 4);
|
||||
return reg_c;
|
||||
}
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort4_t& reg_a, const ushort4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_f32_16x16x16bf16_1k(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_4x4x2bf16;
|
||||
struct intrin_mfma_f32_32x32x4bf16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x2bf16<4, 64>
|
||||
struct intrin_mfma_f32_32x32x4bf16<32, 32>
|
||||
{
|
||||
__device__ static c_vec4_1_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_1_t::VecType reg_c)
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
|
||||
return reg_c;
|
||||
reg_c.template AsType<float16_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float16_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_f32_16x16x8bf16;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_f32_4x4x2bf16<8, 64>
|
||||
struct intrin_mfma_f32_16x16x8bf16<16, 16>
|
||||
{
|
||||
__device__ static c_vec4_2_t::VecType
|
||||
run(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec4_2_t::VecType reg_c)
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const ushort2_t& reg_a, const ushort2_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.s.x = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.x, 4, 0, 0);
|
||||
reg_c.s.y = llvm_intrin_amdgcn_mfma_f32_4x4x2bf16(reg_a[0], reg_b[0], reg_c.s.y, 4, 1, 0);
|
||||
return reg_c;
|
||||
reg_c.template AsType<float4_t>()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x8bf16(
|
||||
reg_a, reg_b, reg_c.template AsType<float4_t>()[Number<0>{}], 0, 0, 0);
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_i32_32x32x8i8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_i32_32x32x8i8<32, 32>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x16_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_i32_32x32x8i8(as_type<int>(reg_a),
|
||||
as_type<int>(reg_b),
|
||||
reg_c.template AsType<int32x16_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
template <index_t MPerWave, index_t NPerWave>
|
||||
struct intrin_mfma_i32_16x16x16i8;
|
||||
|
||||
template <>
|
||||
struct intrin_mfma_i32_16x16x16i8<16, 16>
|
||||
{
|
||||
template <class FloatC>
|
||||
__device__ static void Run(const int8x4_t& reg_a, const int8x4_t& reg_b, FloatC& reg_c)
|
||||
{
|
||||
reg_c.template AsType<int32x4_t>()(Number<0>{}) =
|
||||
llvm_intrin_amdgcn_mfma_i32_16x16x16i8(as_type<int>(reg_a),
|
||||
as_type<int>(reg_b),
|
||||
reg_c.template AsType<int32x4_t>()[Number<0>{}],
|
||||
0,
|
||||
0,
|
||||
0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
2
external/rocm/include/bfloat16_dev.hpp
vendored
2
external/rocm/include/bfloat16_dev.hpp
vendored
@@ -31,7 +31,7 @@ extern "C" {
|
||||
#endif
|
||||
|
||||
#ifdef __HIP_PLATFORM_HCC__
|
||||
#define EXECUTION_SPECIFIER __device__
|
||||
#define EXECUTION_SPECIFIER __device__ __host__
|
||||
#else
|
||||
#define EXECUTION_SPECIFIER
|
||||
#endif // MIOPEN_BACKEND_HIP
|
||||
|
||||
@@ -325,30 +325,30 @@ int main(int argc, char* argv[])
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
out.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
out.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
@@ -80,13 +80,29 @@ void host_convolution_forward(const Tensor<TIn>& in,
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[3])
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei(k, c, y, x));
|
||||
if constexpr(is_same<TIn, ushort>::value)
|
||||
{
|
||||
v += bfloat16_to_float(in(n, c, hi, wi)) *
|
||||
bfloat16_to_float(wei(k, c, y, x));
|
||||
}
|
||||
else
|
||||
{
|
||||
v += static_cast<const double>(in(n, c, hi, wi)) *
|
||||
static_cast<const double>(wei(k, c, y, x));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, k, ho, wo) = v;
|
||||
|
||||
if constexpr(is_same<TOut, ushort>::value)
|
||||
{
|
||||
out(n, k, ho, wo) = float_to_bfloat16(v);
|
||||
}
|
||||
else
|
||||
{
|
||||
out(n, k, ho, wo) = v;
|
||||
}
|
||||
};
|
||||
|
||||
auto f_nhwc = [&](auto n, auto ho, auto wo, auto k) {
|
||||
@@ -102,13 +118,28 @@ void host_convolution_forward(const Tensor<TIn>& in,
|
||||
if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
|
||||
wi < in.mDesc.GetLengths()[2])
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(wei(k, y, x, c));
|
||||
if constexpr(is_same<TIn, ushort>::value)
|
||||
{
|
||||
v += bfloat16_to_float(in(n, hi, wi, c)) *
|
||||
bfloat16_to_float(wei(k, y, x, c));
|
||||
}
|
||||
else
|
||||
{
|
||||
v += static_cast<const double>(in(n, hi, wi, c)) *
|
||||
static_cast<const double>(wei(k, y, x, c));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
out(n, ho, wo, k) = v;
|
||||
if constexpr(is_same<TOut, ushort>::value)
|
||||
{
|
||||
out(n, ho, wo, k) = float_to_bfloat16(v);
|
||||
}
|
||||
else
|
||||
{
|
||||
out(n, ho, wo, k) = v;
|
||||
}
|
||||
};
|
||||
|
||||
if(layout == ConvTensorLayout::NCHW)
|
||||
@@ -226,10 +257,14 @@ int main(int argc, char* argv[])
|
||||
using in_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = float;
|
||||
#elif 1
|
||||
#elif 0
|
||||
using in_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = half_t;
|
||||
#elif 1
|
||||
using in_data_t = ushort;
|
||||
using acc_data_t = float;
|
||||
using out_data_t = ushort;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
@@ -295,30 +330,30 @@ int main(int argc, char* argv[])
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{0.0, 1.0}, num_thread);
|
||||
wei.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.5, 0.5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
|
||||
|
||||
auto gen_wei = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
return GeneratorTensor_2<in_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
wei.GenerateTensorValue(gen_wei, num_thread);
|
||||
}
|
||||
|
||||
@@ -297,30 +297,30 @@ int main(int argc, char* argv[])
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_1<in_data_t>{}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_1<out_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{-5, 5}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_2<out_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 5:
|
||||
in.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_3<float>{-0.1, 0.1}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_3<in_data_t>{-0.1, 0.1}, num_thread);
|
||||
out.GenerateTensorValue(GeneratorTensor_3<out_data_t>{-0.1, 0.1}, num_thread);
|
||||
break;
|
||||
default:
|
||||
in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
|
||||
in.GenerateTensorValue(GeneratorTensor_2<in_data_t>{1, 5}, num_thread);
|
||||
|
||||
auto gen_out = [](auto... is) {
|
||||
return GeneratorTensor_2{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
return GeneratorTensor_2<out_data_t>{1, 5}(is...) * GeneratorTensor_Checkboard{}(is...);
|
||||
};
|
||||
out.GenerateTensorValue(gen_out, num_thread);
|
||||
}
|
||||
|
||||
@@ -239,10 +239,14 @@ int main(int argc, char* argv[])
|
||||
using ab_data_t = float;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = float;
|
||||
#elif 1
|
||||
#elif 0
|
||||
using ab_data_t = half_t;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = half_t;
|
||||
#elif 1
|
||||
using ab_data_t = ushort;
|
||||
using acc_data_t = float;
|
||||
using c_data_t = ushort;
|
||||
#elif 1
|
||||
using ab_data_t = int8_t;
|
||||
using acc_data_t = int32_t;
|
||||
@@ -321,24 +325,24 @@ int main(int argc, char* argv[])
|
||||
// no initialization
|
||||
break;
|
||||
case 1:
|
||||
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
break;
|
||||
case 2:
|
||||
a.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
a.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
case 3:
|
||||
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
|
||||
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_1<ab_data_t>{}, num_thread);
|
||||
break;
|
||||
case 4:
|
||||
a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
|
||||
a.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_2<ab_data_t>{-5, 5}, num_thread);
|
||||
break;
|
||||
default:
|
||||
a.GenerateTensorValue(GeneratorTensor_3<float>{0.0, 1.0}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_3<float>{-0.5, 0.5}, num_thread);
|
||||
a.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{0.0, 1.0}, num_thread);
|
||||
b.GenerateTensorValue(GeneratorTensor_3<ab_data_t>{-0.5, 0.5}, num_thread);
|
||||
}
|
||||
|
||||
#if USE_GEMM_XDL_MK_KN_MN
|
||||
|
||||
@@ -1,6 +1,162 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
|
||||
template <>
|
||||
void host_gemm<ushort, ushort, ushort>(const Tensor<ushort>& a,
|
||||
const Tensor<ushort>& b,
|
||||
Tensor<ushort>& c,
|
||||
const GemmMatrixLayout layout)
|
||||
{
|
||||
if(layout == GemmMatrixLayout::MK_KN_MN)
|
||||
{
|
||||
auto f_mk_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_MN)
|
||||
{
|
||||
auto f_mk_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_MN)
|
||||
{
|
||||
auto f_km_kn_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
|
||||
}
|
||||
|
||||
c(m, n) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_MN)
|
||||
{
|
||||
auto f_km_nk_mn = [&](auto m, auto n) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
|
||||
}
|
||||
|
||||
c(m, n) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_KN_NM)
|
||||
{
|
||||
auto f_mk_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::MK_NK_NM)
|
||||
{
|
||||
auto f_mk_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[1];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(m, k)) * bfloat16_to_float(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_KN_NM)
|
||||
{
|
||||
auto f_km_kn_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(k, n));
|
||||
}
|
||||
|
||||
c(n, m) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else if(layout == GemmMatrixLayout::KM_NK_NM)
|
||||
{
|
||||
auto f_km_nk_nm = [&](auto n, auto m) {
|
||||
const int K = a.mDesc.GetLengths()[0];
|
||||
|
||||
double v = 0;
|
||||
|
||||
for(int k = 0; k < K; ++k)
|
||||
{
|
||||
v += bfloat16_to_float(a(k, m)) * bfloat16_to_float(b(n, k));
|
||||
}
|
||||
|
||||
c(n, m) = float_to_bfloat16(v);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])(
|
||||
std::thread::hardware_concurrency());
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! not supported layout");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename AType, typename BType, typename CType>
|
||||
void host_gemm_mk_kn_mn(const Tensor<AType>& a_m_k,
|
||||
const Tensor<BType>& b_k_n,
|
||||
|
||||
@@ -321,4 +321,41 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
|
||||
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;
|
||||
}
|
||||
|
||||
float bf16_to_f32(ushort src_val)
|
||||
{
|
||||
typedef union
|
||||
{
|
||||
ushort x, y;
|
||||
float f32;
|
||||
} bf16_f32_t;
|
||||
|
||||
bf16_f32_t v;
|
||||
v.x = 0;
|
||||
v.y = src_val;
|
||||
return v.f32;
|
||||
}
|
||||
|
||||
template <>
|
||||
void check_error<ushort>(const Tensor<ushort>& ref, const Tensor<ushort>& result)
|
||||
{
|
||||
float error = 0;
|
||||
float max_diff = -1;
|
||||
float ref_value = 0, result_value = 0;
|
||||
for(int i = 0; i < ref.mData.size(); ++i)
|
||||
{
|
||||
error += std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
|
||||
float diff = std::abs(bf16_to_f32(ref.mData[i]) - bf16_to_f32(result.mData[i]));
|
||||
if(max_diff < diff)
|
||||
{
|
||||
max_diff = diff;
|
||||
ref_value = bf16_to_f32(ref.mData[i]);
|
||||
result_value = bf16_to_f32(result.mData[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "error: " << error << std::endl;
|
||||
std::cout << "max_diff: " << max_diff << ", ref: " << ref_value << ", res: " << result_value
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include <cmath>
|
||||
#include "config.hpp"
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_1
|
||||
{
|
||||
int value = 1;
|
||||
@@ -15,6 +16,30 @@ struct GeneratorTensor_1
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<ushort>
|
||||
{
|
||||
float value = 1.0;
|
||||
|
||||
template <typename... Is>
|
||||
ushort operator()(Is...)
|
||||
{
|
||||
return float_to_bfloat16(value);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_1<int8_t>
|
||||
{
|
||||
int8_t value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
int8_t operator()(Is...)
|
||||
{
|
||||
return value;
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_0
|
||||
{
|
||||
int value = 0;
|
||||
@@ -26,6 +51,7 @@ struct GeneratorTensor_0
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_2
|
||||
{
|
||||
int min_value = 0;
|
||||
@@ -38,6 +64,33 @@ struct GeneratorTensor_2
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<ushort>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ushort operator()(Is...)
|
||||
{
|
||||
float tmp = (std::rand() % (max_value - min_value)) + min_value;
|
||||
return float_to_bfloat16(tmp);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_2<int8_t>
|
||||
{
|
||||
int min_value = 0;
|
||||
int max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
int8_t operator()(Is...)
|
||||
{
|
||||
return (std::rand() % (max_value - min_value)) + min_value;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GeneratorTensor_3
|
||||
{
|
||||
@@ -53,6 +106,39 @@ struct GeneratorTensor_3
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<ushort>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
ushort operator()(Is...)
|
||||
{
|
||||
float tmp = float(std::rand()) / float(RAND_MAX);
|
||||
|
||||
float fp32_tmp = min_value + tmp * (max_value - min_value);
|
||||
|
||||
return float_to_bfloat16(fp32_tmp);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GeneratorTensor_3<int8_t>
|
||||
{
|
||||
float min_value = 0;
|
||||
float max_value = 1;
|
||||
|
||||
template <typename... Is>
|
||||
int8_t operator()(Is...)
|
||||
{
|
||||
int8_t min_tmp = static_cast<int8_t>(min_value);
|
||||
int8_t max_tmp = static_cast<int8_t>(max_value);
|
||||
|
||||
return (std::rand() % (max_tmp - min_tmp)) + min_tmp;
|
||||
}
|
||||
};
|
||||
|
||||
struct GeneratorTensor_Checkboard
|
||||
{
|
||||
template <typename... Ts>
|
||||
|
||||
Reference in New Issue
Block a user