mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Hybrid direct + implicit GEMM forward convolution NCHWc v5r1 (#25)
* Hybrid direct + implicit GEMM forward convolution NCHWc v5r1. Input tensor bypass LDS. Support fp32/fp16/int8
This commit is contained in:
@@ -38,7 +38,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
typename InRightPads>
|
||||
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_global_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
@@ -51,17 +51,21 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
@@ -78,7 +82,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
const auto InRightPadW = in_right_pads[I1];
|
||||
|
||||
// weight tensor
|
||||
const auto wei_gemmk_gemmm_global_desc = transform_dynamic_tensor_descriptor(
|
||||
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -104,7 +108,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_gemmk_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
||||
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
@@ -114,31 +118,31 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_gemmm_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo)),
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Ho),
|
||||
make_pass_through_transform(Wo)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
if(!(K % KPerBlock == 0 && Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0 &&
|
||||
E % EPerBlock == 0))
|
||||
if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_k_m_global_iterator_hacks =
|
||||
constexpr auto a_e_k_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_k_m_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
|
||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_k_n_global_iterator_hacks =
|
||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
@@ -148,17 +152,17 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_k_n_global_move_slice_window_iterator_hack =
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_h_w_global_tensor_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{},
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
@@ -171,9 +175,9 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
decltype(in_gemmk_n_ho_wo_global_desc),
|
||||
decltype(out_gemmm_n_ho_wo_global_desc),
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
@@ -196,13 +200,13 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_k_m_global_iterator_hacks),
|
||||
decltype(b_k_n_global_iterator_hacks),
|
||||
decltype(c_k_n_h_w_global_tensor_iterator_hacks),
|
||||
decltype(a_k_m_global_move_slice_window_iterator_hack),
|
||||
decltype(b_k_n_global_move_slice_window_iterator_hack)>;
|
||||
decltype(a_e_k_global_iterator_hacks),
|
||||
decltype(b_e_n_ho_wo_global_iterator_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_iterator_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;
|
||||
|
||||
@@ -226,108 +230,104 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_gemmm_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_gemmk_n_ho_wo_global_desc,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_gemmm_n_ho_wo_global_desc,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_gemmm_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_gemmk_n_ho_wo_global_desc,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_gemmm_n_ho_wo_global_desc,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_gemmm_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_gemmk_n_ho_wo_global_desc,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_gemmm_n_ho_wo_global_desc,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_gemmk_gemmm_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_gemmk_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_gemmm_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
const auto kernel = run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_ho_wo_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_gemmk_gemmm_global_desc,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_gemmk_n_ho_wo_global_desc,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_gemmm_n_ho_wo_global_desc,
|
||||
out_k_n_ho_wo_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
@@ -340,7 +340,7 @@ struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
|
||||
|
||||
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k_ho_wo_global_desc) /
|
||||
out_n_k0_ho_wo_k1_global_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
|
||||
@@ -0,0 +1,368 @@
|
||||
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_dynamic_gemm_v2.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <index_t BlockSize,
|
||||
typename FloatAB,
|
||||
typename FloatAcc,
|
||||
typename FloatC,
|
||||
index_t KPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t EPerThread,
|
||||
typename ABlockTransferThreadSliceLengths_E_K,
|
||||
typename ABlockTransferThreadClusterLengths_E_K,
|
||||
index_t ABlockTransferSrcScalarPerVector_E,
|
||||
index_t ABlockTransferDstScalarPerVector_K,
|
||||
index_t BThreadTransferSrcScalarPerVector_W,
|
||||
index_t CThreadTransferDstScalarPerVector_W>
|
||||
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
|
||||
{
|
||||
template <typename... Wei,
|
||||
typename... In,
|
||||
typename... Out,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
__host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
|
||||
const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
|
||||
const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
|
||||
const ConvStrides& conv_strides,
|
||||
const ConvDilations& conv_dilations,
|
||||
const InLeftPads& in_left_pads,
|
||||
const InRightPads& in_right_pads,
|
||||
const FloatAB* __restrict__ p_wei_global,
|
||||
const FloatAB* __restrict__ p_in_global,
|
||||
FloatC* __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I4 = Number<4>{};
|
||||
|
||||
const auto N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
const auto C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);
|
||||
|
||||
const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
|
||||
const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);
|
||||
|
||||
const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);
|
||||
|
||||
const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
|
||||
const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
const auto X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
|
||||
const auto ConvStrideH = conv_strides[I0];
|
||||
const auto ConvStrideW = conv_strides[I1];
|
||||
|
||||
const auto ConvDilationH = conv_dilations[I0];
|
||||
const auto ConvDilationW = conv_dilations[I1];
|
||||
|
||||
const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
|
||||
const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;
|
||||
|
||||
const auto OutRightPadH = Hop - Ho;
|
||||
const auto OutRightPadW = Wop - Wo;
|
||||
|
||||
const auto InLeftPadH = in_left_pads[I0];
|
||||
const auto InLeftPadW = in_left_pads[I1];
|
||||
|
||||
const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
|
||||
const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;
|
||||
|
||||
std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
|
||||
<< std::endl;
|
||||
std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
|
||||
<< std::endl;
|
||||
|
||||
// weight tensor
|
||||
const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
// input tensor
|
||||
const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(C),
|
||||
make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
|
||||
|
||||
const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
|
||||
in_n_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(C, Y, X)),
|
||||
make_pass_through_transform(N),
|
||||
make_pass_through_transform(Hop),
|
||||
make_pass_through_transform(Wop)),
|
||||
make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// output tensor
|
||||
const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
|
||||
make_tuple(make_merge_transform(make_tuple(K0, K1)),
|
||||
make_pass_through_transform(N),
|
||||
make_pad_transform(Ho, 0, OutRightPadH),
|
||||
make_pad_transform(Wo, 0, OutRightPadW)),
|
||||
make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto E = C * Y * X;
|
||||
|
||||
std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;
|
||||
|
||||
if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
|
||||
(E % EPerBlock) == 0))
|
||||
{
|
||||
throw std::runtime_error("wrong! GEMM size no divisible");
|
||||
}
|
||||
|
||||
// hack to control index calculation when iterating over a_k_m_global tensor
|
||||
constexpr auto a_e_k_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));
|
||||
|
||||
constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));
|
||||
|
||||
constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};
|
||||
|
||||
// hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
|
||||
// hack for NKHW format
|
||||
constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}),
|
||||
make_tuple(Sequence<0, 2, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{},
|
||||
Sequence<0, 0, 0, 0, 0>{}));
|
||||
|
||||
// GEMM
|
||||
using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
|
||||
BlockSize,
|
||||
FloatAB,
|
||||
FloatAcc,
|
||||
FloatC,
|
||||
InMemoryDataOperation::Set,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
Sequence<1, 0>,
|
||||
Sequence<1, 0>,
|
||||
0,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
false, // don't move back src coordinate after threadwise copy
|
||||
Sequence<0, 2, 3, 1>,
|
||||
3,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
false, // don't move back src coordinate after threadwise copy, which will be fused with
|
||||
// MoveSrcSliceWindow() to save addr computation
|
||||
Sequence<0, 2, 3, 1>,
|
||||
0,
|
||||
CThreadTransferDstScalarPerVector_W,
|
||||
decltype(a_e_k_global_iterator_hacks),
|
||||
decltype(b_e_n_ho_wo_global_iterator_hacks),
|
||||
decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
|
||||
decltype(a_e_k_global_move_slice_window_iterator_hack),
|
||||
decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;
|
||||
|
||||
const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;
|
||||
|
||||
const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;
|
||||
|
||||
const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;
|
||||
|
||||
index_t nrepeat = 100;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
|
||||
<< " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
|
||||
<< std::endl;
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
if(has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, true>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, true>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, true>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel =
|
||||
run_gridwise_operation<gridwise_gemm,
|
||||
decltype(wei_e_k_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(in_e_n_ho_wo_global_desc),
|
||||
const FloatAB*,
|
||||
decltype(out_k_n_hop_wop_global_desc),
|
||||
FloatC*,
|
||||
integral_constant<bool, false>,
|
||||
integral_constant<bool, false>>;
|
||||
|
||||
launch_kernel(kernel,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
wei_e_k_global_desc,
|
||||
p_wei_global,
|
||||
in_e_n_ho_wo_global_desc,
|
||||
p_in_global,
|
||||
out_k_n_hop_wop_global_desc,
|
||||
p_out_global,
|
||||
integral_constant<bool, false>{},
|
||||
integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
|
||||
wei_k_c_y_x_global_desc,
|
||||
out_n_k0_ho_wo_k1_global_desc) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -134,9 +134,12 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
|
||||
constexpr auto KPerThreadSubC = 4;
|
||||
|
||||
constexpr auto HoPerThreadSubC = 2;
|
||||
constexpr auto WoPerThreadSubC = 2;
|
||||
|
||||
static_assert(KPerThread % KPerThreadSubC == 0, "");
|
||||
static_assert(HPerThread % 2 == 0, "");
|
||||
static_assert(WPerThread % 2 == 0, "");
|
||||
static_assert(HPerThread % HoPerThreadSubC == 0, "");
|
||||
static_assert(WPerThread % WoPerThreadSubC == 0, "");
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
|
||||
@@ -158,7 +161,9 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
|
||||
constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
|
||||
decltype(b_thread_mtx),
|
||||
decltype(c_thread_mtx)>{};
|
||||
decltype(c_thread_mtx),
|
||||
HoPerThreadSubC,
|
||||
WoPerThreadSubC>{};
|
||||
// loop over k
|
||||
#pragma unroll
|
||||
for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
|
||||
@@ -171,10 +176,11 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
|
||||
mMyThreadOffsetA,
|
||||
p_a_thread);
|
||||
|
||||
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += 2)
|
||||
#pragma unroll
|
||||
for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC)
|
||||
{
|
||||
|
||||
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += 2)
|
||||
#pragma unroll
|
||||
for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC)
|
||||
{
|
||||
threadwise_gemm.Run(p_a_thread,
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(
|
||||
|
||||
@@ -37,6 +37,8 @@ __device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread
|
||||
template <typename ADesc,
|
||||
typename BDesc,
|
||||
typename CDesc,
|
||||
index_t H,
|
||||
index_t W,
|
||||
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
@@ -54,11 +56,6 @@ struct ThreadwiseGemm_km_kn_mn_v3
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
// constexpr auto H = BDesc{}.GetLength(I2);
|
||||
// constexpr auto W = BDesc{}.GetLength(I3);
|
||||
constexpr auto H = 2;
|
||||
constexpr auto W = 2;
|
||||
|
||||
constexpr auto E = ADesc{}.GetLength(I0);
|
||||
constexpr auto K = ADesc{}.GetLength(I1);
|
||||
|
||||
|
||||
@@ -59,7 +59,26 @@ __llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
// half
|
||||
__device__ half_t
|
||||
__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
|
||||
|
||||
__device__ half2_t
|
||||
__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
|
||||
|
||||
__device__ half4_t
|
||||
__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
|
||||
|
||||
// float
|
||||
__device__ float
|
||||
__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
|
||||
index_t voffset,
|
||||
@@ -114,6 +133,28 @@ __llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
// half
|
||||
__device__ void
|
||||
__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
|
||||
|
||||
__device__ void
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t voffset,
|
||||
index_t soffset,
|
||||
index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
|
||||
// float
|
||||
__device__ void
|
||||
__llvm_amdgcn_raw_buffer_store_fp32(float vdata,
|
||||
int32x4_t rsrc,
|
||||
@@ -142,7 +183,13 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
index_t src_wave_addr_offset)
|
||||
{
|
||||
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, half2_t>::value && (N == 1)) ||
|
||||
(is_same<T, half4_t>::value && (N == 1)) ||
|
||||
(is_same<T, half8_t>::value && (N == 1)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)) ||
|
||||
(is_same<T, int32x2_t>::value && (N == 1)) ||
|
||||
(is_same<T, int32x4_t>::value && (N == 1)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, float>::value)
|
||||
@@ -169,8 +216,63 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(float), 0);
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(float),
|
||||
0);
|
||||
|
||||
return tmp.Vector();
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half2_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half4_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half8_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
vector_type<half_t, 8> tmp;
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(half_t),
|
||||
0);
|
||||
|
||||
return tmp.Vector();
|
||||
}
|
||||
@@ -199,12 +301,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
|
||||
tmp.Vectors(Number<4>{})(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, 4 * sizeof(int32_t), 0);
|
||||
tmp.Vectors(Number<4>{})(Number<1>{}) =
|
||||
__llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
|
||||
src_thread_addr_offset,
|
||||
src_wave_addr_offset + 4 * sizeof(int32_t),
|
||||
0);
|
||||
|
||||
return tmp.Vector();
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32x2_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_i32x2(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, int32x4_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
return __llvm_amdgcn_raw_buffer_load_i32x4(
|
||||
src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T, index_t N>
|
||||
@@ -213,10 +334,12 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
||||
index_t dst_thread_addr_offset,
|
||||
index_t dst_wave_addr_offset)
|
||||
{
|
||||
static_assert((is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4)),
|
||||
"wrong! not implemented");
|
||||
static_assert(
|
||||
(is_same<T, float>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int32_t>::value && (N == 1 || N == 2 || N == 4)) ||
|
||||
(is_same<T, int8_t>::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) ||
|
||||
(is_same<T, half_t>::value && (N == 1 || N == 2 || N == 4 || N == 8)),
|
||||
"wrong! not implemented");
|
||||
|
||||
if constexpr(is_same<T, float>::value)
|
||||
{
|
||||
@@ -298,6 +421,65 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 16)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
}
|
||||
else if constexpr(is_same<T, half_t>::value)
|
||||
{
|
||||
if constexpr(N == 1)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 2)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 4)
|
||||
{
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
}
|
||||
else if constexpr(N == 8)
|
||||
{
|
||||
vector_type<half_t, 8> tmp{src_thread_data};
|
||||
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<0>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset,
|
||||
0);
|
||||
|
||||
__llvm_amdgcn_raw_buffer_store_fp16x4(tmp.Vectors(Number<4>{})[Number<1>{}],
|
||||
dst_wave_buffer_resource,
|
||||
dst_thread_addr_offset,
|
||||
dst_wave_addr_offset + 4 * sizeof(half_t),
|
||||
0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -166,6 +166,53 @@ __device__ void amd_assembly_outer_product_1x4(half4_t a,
|
||||
"3"(c3));
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(half8_t a,
|
||||
half8_t b0,
|
||||
half8_t b1,
|
||||
half8_t b2,
|
||||
half8_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
|
||||
const half4_t* p_a_half4 = reinterpret_cast<const half4_t*>(&a);
|
||||
const half4_t* p_b0_half4 = reinterpret_cast<const half4_t*>(&b0);
|
||||
const half4_t* p_b1_half4 = reinterpret_cast<const half4_t*>(&b1);
|
||||
const half4_t* p_b2_half4 = reinterpret_cast<const half4_t*>(&b2);
|
||||
const half4_t* p_b3_half4 = reinterpret_cast<const half4_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half4[0], p_b0_half4[0], p_b1_half4[0], p_b2_half4[0], p_b3_half4[0], c0, c1, c2, c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half4[1], p_b0_half4[1], p_b1_half4[1], p_b2_half4[1], p_b3_half4[1], c0, c1, c2, c3);
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(half16_t a,
|
||||
half16_t b0,
|
||||
half16_t b1,
|
||||
half16_t b2,
|
||||
half16_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
const half8_t* p_a_half8 = reinterpret_cast<const half8_t*>(&a);
|
||||
const half8_t* p_b0_half8 = reinterpret_cast<const half8_t*>(&b0);
|
||||
const half8_t* p_b1_half8 = reinterpret_cast<const half8_t*>(&b1);
|
||||
const half8_t* p_b2_half8 = reinterpret_cast<const half8_t*>(&b2);
|
||||
const half8_t* p_b3_half8 = reinterpret_cast<const half8_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half8[0], p_b0_half8[0], p_b1_half8[0], p_b2_half8[0], p_b3_half8[0], c0, c1, c2, c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(
|
||||
p_a_half8[1], p_b0_half8[1], p_b1_half8[1], p_b2_half8[1], p_b3_half8[1], c0, c1, c2, c3);
|
||||
}
|
||||
|
||||
// c0 += inner_product(a, b0)
|
||||
// c1 += inner_product(a, b1)
|
||||
__device__ void
|
||||
@@ -215,5 +262,82 @@ __device__ void amd_assembly_outer_product_1x4(int8x4_t a,
|
||||
#endif
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(int8x8_t a,
|
||||
int8x8_t b0,
|
||||
int8x8_t b1,
|
||||
int8x8_t b2,
|
||||
int8x8_t b3,
|
||||
int32_t& c0,
|
||||
int32_t& c1,
|
||||
int32_t& c2,
|
||||
int32_t& c3)
|
||||
{
|
||||
|
||||
const int8x4_t* p_a_int8x4_t = reinterpret_cast<const int8x4_t*>(&a);
|
||||
const int8x4_t* p_b0_int8x4_t = reinterpret_cast<const int8x4_t*>(&b0);
|
||||
const int8x4_t* p_b1_int8x4_t = reinterpret_cast<const int8x4_t*>(&b1);
|
||||
const int8x4_t* p_b2_int8x4_t = reinterpret_cast<const int8x4_t*>(&b2);
|
||||
const int8x4_t* p_b3_int8x4_t = reinterpret_cast<const int8x4_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x4_t[0],
|
||||
p_b0_int8x4_t[0],
|
||||
p_b1_int8x4_t[0],
|
||||
p_b2_int8x4_t[0],
|
||||
p_b3_int8x4_t[0],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x4_t[1],
|
||||
p_b0_int8x4_t[1],
|
||||
p_b1_int8x4_t[1],
|
||||
p_b2_int8x4_t[1],
|
||||
p_b3_int8x4_t[1],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
}
|
||||
|
||||
__device__ void amd_assembly_outer_product_1x4(int8x16_t a,
|
||||
int8x16_t b0,
|
||||
int8x16_t b1,
|
||||
int8x16_t b2,
|
||||
int8x16_t b3,
|
||||
int32_t& c0,
|
||||
int32_t& c1,
|
||||
int32_t& c2,
|
||||
int32_t& c3)
|
||||
|
||||
{
|
||||
|
||||
const int8x8_t* p_a_int8x8_t = reinterpret_cast<const int8x8_t*>(&a);
|
||||
const int8x8_t* p_b0_int8x8_t = reinterpret_cast<const int8x8_t*>(&b0);
|
||||
const int8x8_t* p_b1_int8x8_t = reinterpret_cast<const int8x8_t*>(&b1);
|
||||
const int8x8_t* p_b2_int8x8_t = reinterpret_cast<const int8x8_t*>(&b2);
|
||||
const int8x8_t* p_b3_int8x8_t = reinterpret_cast<const int8x8_t*>(&b3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x8_t[0],
|
||||
p_b0_int8x8_t[0],
|
||||
p_b1_int8x8_t[0],
|
||||
p_b2_int8x8_t[0],
|
||||
p_b3_int8x8_t[0],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
|
||||
amd_assembly_outer_product_1x4(p_a_int8x8_t[1],
|
||||
p_b0_int8x8_t[1],
|
||||
p_b1_int8x8_t[1],
|
||||
p_b2_int8x8_t[1],
|
||||
p_b3_int8x8_t[1],
|
||||
c0,
|
||||
c1,
|
||||
c2,
|
||||
c3);
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -168,6 +168,84 @@ struct vector_type<T, 8>
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct vector_type<T, 16>
|
||||
{
|
||||
using d1_t = T;
|
||||
typedef T d2_t __attribute__((ext_vector_type(2)));
|
||||
typedef T d4_t __attribute__((ext_vector_type(4)));
|
||||
typedef T d8_t __attribute__((ext_vector_type(8)));
|
||||
typedef T d16_t __attribute__((ext_vector_type(16)));
|
||||
|
||||
using type = d16_t;
|
||||
|
||||
union
|
||||
{
|
||||
d16_t d16_;
|
||||
StaticallyIndexedArray<d1_t, 16> d1x16_;
|
||||
StaticallyIndexedArray<d2_t, 8> d2x8_;
|
||||
StaticallyIndexedArray<d4_t, 4> d4x4_;
|
||||
StaticallyIndexedArray<d8_t, 2> d8x2_;
|
||||
StaticallyIndexedArray<d16_t, 1> d16x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 16; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
|
||||
};
|
||||
|
||||
// fp32
|
||||
using float2_t = typename vector_type<float, 2>::type;
|
||||
using float4_t = typename vector_type<float, 4>::type;
|
||||
using float8_t = typename vector_type<float, 8>::type;
|
||||
|
||||
// fp16
|
||||
using half_t = _Float16;
|
||||
using half2_t = typename vector_type<half_t, 2>::type;
|
||||
using half4_t = typename vector_type<half_t, 4>::type;
|
||||
using half8_t = typename vector_type<half_t, 8>::type;
|
||||
using half16_t = typename vector_type<half_t, 16>::type;
|
||||
|
||||
// bfp16
|
||||
using ushort2_t = typename vector_type<ushort, 2>::type;
|
||||
using ushort4_t = typename vector_type<ushort, 4>::type;
|
||||
using ushort8_t = typename vector_type<ushort, 8>::type;
|
||||
|
||||
// i32
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
|
||||
template <>
|
||||
struct vector_type<int8_t, 2>
|
||||
{
|
||||
@@ -250,31 +328,118 @@ struct vector_type<int8_t, 4>
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x1_; }
|
||||
};
|
||||
|
||||
// fp32
|
||||
using float2_t = typename vector_type<float, 2>::type;
|
||||
using float4_t = typename vector_type<float, 4>::type;
|
||||
using float8_t = typename vector_type<float, 8>::type;
|
||||
template <>
|
||||
struct vector_type<int8_t, 8>
|
||||
{
|
||||
using d1_t = int8_t;
|
||||
typedef int16_t d2_t;
|
||||
typedef int32_t d4_t;
|
||||
typedef int32x2_t d8_t;
|
||||
|
||||
// fp16
|
||||
using half_t = _Float16;
|
||||
using half2_t = typename vector_type<half_t, 2>::type;
|
||||
using half4_t = typename vector_type<half_t, 4>::type;
|
||||
using half8_t = typename vector_type<half_t, 8>::type;
|
||||
using type = d8_t;
|
||||
|
||||
// bfp16
|
||||
using ushort2_t = typename vector_type<ushort, 2>::type;
|
||||
using ushort4_t = typename vector_type<ushort, 4>::type;
|
||||
using ushort8_t = typename vector_type<ushort, 8>::type;
|
||||
union
|
||||
{
|
||||
d8_t d8_;
|
||||
StaticallyIndexedArray<d1_t, 8> d1x8_;
|
||||
StaticallyIndexedArray<d2_t, 4> d2x4_;
|
||||
StaticallyIndexedArray<d4_t, 2> d4x2_;
|
||||
StaticallyIndexedArray<d8_t, 1> d8x1_;
|
||||
} data_;
|
||||
|
||||
// i32
|
||||
using int32x2_t = typename vector_type<int32_t, 2>::type;
|
||||
using int32x4_t = typename vector_type<int32_t, 4>::type;
|
||||
using int32x8_t = typename vector_type<int32_t, 8>::type;
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 8; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x1_; }
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<int8_t, 16>
|
||||
{
|
||||
using d1_t = int8_t;
|
||||
typedef int16_t d2_t;
|
||||
typedef int32_t d4_t;
|
||||
typedef int32x2_t d8_t;
|
||||
typedef int32x4_t d16_t;
|
||||
|
||||
using type = d16_t;
|
||||
|
||||
union
|
||||
{
|
||||
d16_t d16_;
|
||||
StaticallyIndexedArray<d1_t, 16> d1x16_;
|
||||
StaticallyIndexedArray<d2_t, 8> d2x8_;
|
||||
StaticallyIndexedArray<d4_t, 4> d4x4_;
|
||||
StaticallyIndexedArray<d8_t, 2> d8x2_;
|
||||
StaticallyIndexedArray<d8_t, 1> d16x1_;
|
||||
} data_;
|
||||
|
||||
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
|
||||
|
||||
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
|
||||
|
||||
__host__ __device__ static constexpr index_t Size() { return 16; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vector() const { return data_.d16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vector() { return data_.d16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Scalars() const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Scalars() { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<1>) const { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<2>) const { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<4>) const { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<8>) const { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr const auto& Vectors(Number<16>) const { return data_.d16x1_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<1>) { return data_.d1x16_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<2>) { return data_.d2x8_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<4>) { return data_.d4x4_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<8>) { return data_.d8x2_; }
|
||||
|
||||
__host__ __device__ constexpr auto& Vectors(Number<16>) { return data_.d16x1_; }
|
||||
};
|
||||
|
||||
// i8
|
||||
// hack for int8x4_t, because compiler does not have native support for int8x4_t
|
||||
// int8x4_t is defined as int32_t
|
||||
using int8x4_t = typename vector_type<int8_t, 4>::type;
|
||||
using int8x4_t = typename vector_type<int8_t, 4>::type;
|
||||
using int8x8_t = typename vector_type<int8_t, 8>::type;
|
||||
using int8x16_t = typename vector_type<int8_t, 16>::type;
|
||||
|
||||
// data type conversion
|
||||
template <typename T>
|
||||
@@ -339,6 +504,34 @@ struct inner_product_with_conversion
|
||||
|
||||
return acc;
|
||||
}
|
||||
|
||||
__device__ T operator()(int8x8_t a, int8x8_t b) const
|
||||
{
|
||||
const vector_type<int8_t, 8> a_vector{a};
|
||||
const vector_type<int8_t, 8> b_vector{b};
|
||||
|
||||
T acc = 0;
|
||||
|
||||
static_for<0, 8, 1>{}([&](auto i) {
|
||||
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
|
||||
});
|
||||
|
||||
return acc;
|
||||
}
|
||||
|
||||
__device__ T operator()(int8x16_t a, int8x16_t b) const
|
||||
{
|
||||
const vector_type<int8_t, 16> a_vector{a};
|
||||
const vector_type<int8_t, 16> b_vector{b};
|
||||
|
||||
T acc = 0;
|
||||
|
||||
static_for<0, 16, 1>{}([&](auto i) {
|
||||
acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
|
||||
});
|
||||
|
||||
return acc;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "driver_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw_outpad.hpp"
|
||||
|
||||
template <class TInWei,
|
||||
ck::index_t InWeiVectorSize,
|
||||
@@ -57,6 +58,9 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
constexpr auto C0 = C / Number<InWeiVectorSize>{};
|
||||
constexpr auto C1 = Number<InWeiVectorSize>{};
|
||||
|
||||
constexpr auto K0 = K / Number<InWeiVectorSize>{};
|
||||
constexpr auto K1 = Number<InWeiVectorSize>{};
|
||||
|
||||
#if 0
|
||||
// run-time variables
|
||||
const auto in_n_c_hi_wi_desc =
|
||||
@@ -76,19 +80,21 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, C0, Hi, Wi));
|
||||
const auto wei_k_c0_y_x_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C0, Y, X));
|
||||
const auto out_n_k_ho_wo_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho, Wo));
|
||||
const auto out_n_k0_ho_wo_k1_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1));
|
||||
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
const auto conv_strides = sequence_to_tuple_of_number(ConvStrides{});
|
||||
const auto conv_dilations = sequence_to_tuple_of_number(ConvDilations{});
|
||||
const auto in_left_pads = sequence_to_tuple_of_number(InLeftPads{});
|
||||
const auto in_right_pads = sequence_to_tuple_of_number(InRightPads{});
|
||||
#endif
|
||||
|
||||
Tensor<TInWei> in_n_c0_hi_wi_c1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<N, C0, Hi, Wi, C1>{})));
|
||||
Tensor<TInWei> wei_k_c0_y_x_c1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<K, C0, Y, X, C1>{})));
|
||||
Tensor<TOut> out_n_k0_ho_wo_k1(make_HostTensorDescriptor(
|
||||
make_native_tensor_descriptor_packed(Sequence<N, K0, Ho, Wo, K1>{})));
|
||||
|
||||
auto f_nchw2nc0hwc1 = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_n_c0_hi_wi_c1(n, c / InWeiVectorSize, hi, wi, c % InWeiVectorSize) =
|
||||
@@ -106,13 +112,38 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
in_n_c_hi_wi_device_buf.ToDevice(in_n_c0_hi_wi_c1.mData.data());
|
||||
wei_k_c_y_x_device_buf.ToDevice(wei_k_c0_y_x_c1.mData.data());
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 64, 16x8x32x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = K;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = C0;
|
||||
|
||||
constexpr index_t KPerThread = KPerBlock;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
constexpr index_t WoPerThread = 2;
|
||||
constexpr index_t EPerThread = EPerBlock;
|
||||
|
||||
using ABlockTransferThreadSliceLengths_E_K = Sequence<3, 1>;
|
||||
using ABlockTransferThreadClusterLengths_E_K = Sequence<3 * EPerBlock, KPerBlock>;
|
||||
|
||||
constexpr index_t ABlockTransferSrcScalarPerVector_E = 1;
|
||||
constexpr index_t ABlockTransferDstScalarPerVector_K = 1;
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#else
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 16;
|
||||
constexpr index_t HoPerBlock = 8;
|
||||
constexpr index_t WoPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 4;
|
||||
constexpr index_t EPerBlock = 1;
|
||||
|
||||
constexpr index_t KPerThread = 16;
|
||||
constexpr index_t HoPerThread = 2;
|
||||
@@ -127,32 +158,28 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
|
||||
constexpr index_t BThreadTransferSrcScalarPerVector_W = 1;
|
||||
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = 1;
|
||||
constexpr index_t CThreadTransferDstScalarPerVector_W = K1;
|
||||
|
||||
static_assert(KPerThread % CThreadTransferDstScalarPerVector_W == 0, "");
|
||||
#endif
|
||||
|
||||
constexpr auto conv_driver =
|
||||
#if 0
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad<
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type,
|
||||
TAcc,
|
||||
TOut,
|
||||
KPerBlock,
|
||||
HoPerBlock,
|
||||
WoPerBlock,
|
||||
EPerBlock,
|
||||
KPerThread,
|
||||
HoPerThread,
|
||||
WoPerThread,
|
||||
EPerThread,
|
||||
ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K,
|
||||
ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K,
|
||||
BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W>{};
|
||||
#else
|
||||
DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad<
|
||||
#endif
|
||||
BlockSize,
|
||||
typename vector_type<TInWei, InWeiVectorSize>::type, TAcc, TOut, KPerBlock,
|
||||
HoPerBlock, WoPerBlock, EPerBlock, KPerThread, HoPerThread, WoPerThread,
|
||||
EPerThread, ABlockTransferThreadSliceLengths_E_K,
|
||||
ABlockTransferThreadClusterLengths_E_K, ABlockTransferSrcScalarPerVector_E,
|
||||
ABlockTransferDstScalarPerVector_K, BThreadTransferSrcScalarPerVector_W,
|
||||
CThreadTransferDstScalarPerVector_W > {};
|
||||
|
||||
conv_driver.Run(wei_k_c0_y_x_desc,
|
||||
in_n_c0_hi_wi_desc,
|
||||
out_n_k_ho_wo_desc,
|
||||
out_n_k0_ho_wo_k1_desc,
|
||||
conv_strides,
|
||||
conv_dilations,
|
||||
in_left_pads,
|
||||
@@ -163,5 +190,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw(
|
||||
in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()));
|
||||
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
|
||||
out_n_k_ho_wo_device_buf.FromDevice(out_n_k0_ho_wo_k1.mData.data());
|
||||
|
||||
auto f_nk0hwk1_to_nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_n_k_ho_wo(n, k, ho, wo) =
|
||||
out_n_k0_ho_wo_k1(n, k / InWeiVectorSize, ho, wo, k % InWeiVectorSize);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nk0hwk1_to_nkhw, N, K, Ho, Wo)();
|
||||
}
|
||||
|
||||
@@ -48,8 +48,8 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 0
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
@@ -62,8 +62,8 @@ int main(int argc, char* argv[])
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using LeftPads = Sequence<0, 0>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
#elif 1
|
||||
constexpr index_t N = 1;
|
||||
constexpr index_t C = 16;
|
||||
@@ -642,7 +642,7 @@ int main(int argc, char* argv[])
|
||||
using out_data_t = int8_t;
|
||||
#elif 1
|
||||
using in_data_t = int8_t;
|
||||
constexpr index_t in_vector_size = 4;
|
||||
constexpr index_t in_vector_size = 16;
|
||||
using acc_data_t = int32_t;
|
||||
using out_data_t = int8_t;
|
||||
#endif
|
||||
@@ -741,7 +741,7 @@ int main(int argc, char* argv[])
|
||||
LeftPads{},
|
||||
RightPads{},
|
||||
nrepeat);
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
|
||||
in_vector_size,
|
||||
acc_data_t,
|
||||
|
||||
Reference in New Issue
Block a user