Mirror of https://github.com/ROCm/composable_kernel.git, synced 2026-05-12 09:16:52 +00:00
[MIOpen Downstream] Initial MIOpen integration (#52)
* update online kernel wrapper: bundle all descriptors in a tuple
* change __CONSTANT__ to CONSTANT
* rename
* add tuning
* add IsValidCompileParameter
* reorganize
* add tunables for fp16 and int8
* fix kernel compile warnings and bugs
* suppress warning about casting a CONSTANT (address space 4) pointer
* fix build issue
@@ -1,296 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGridDesc_GK0_GM0_GM1_GK1,
          typename BGridDesc_GK0_GN0_GN1_GK1,
          typename CGridDesc_GM0_GM1_GN0_GN1,
          index_t GM1PerBlockGM11,
          index_t GN1PerBlockGN11,
          index_t GK0PerBlock,
          index_t BM1PerThreadBM11,
          index_t BN1PerThreadBN11,
          index_t BK0PerThread,
          index_t BM10BN10ThreadClusterBM100,
          index_t BM10BN10ThreadClusterBN100,
          index_t BM10BN10ThreadClusterBM101,
          index_t BM10BN10ThreadClusterBN101,
          typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
                                const FloatAB* p_b_grid,
                                FloatC* p_c_grid,
                                const AGridDesc_GK0_GM0_GM1_GK1& a_grid_desc_gk0_gm0_gm1_gk1,
                                const BGridDesc_GK0_GN0_GN1_GK1& b_grid_desc_gk0_gn0_gn1_gk1,
                                const CGridDesc_GM0_GM1_GN0_GN1& c_grid_desc_gm0_gm1_gn0_gn1,
                                AGridIteratorHacks,
                                BGridIteratorHacks,
                                CGridIteratorHacks,
                                AGridMoveSliceWindowIteratorHacks,
                                BGridMoveSliceWindowIteratorHacks,
                                index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseContraction =
        GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
            CGlobalMemoryDataOperation,
            AGridDesc_GK0_GM0_GM1_GK1,
            BGridDesc_GK0_GN0_GN1_GK1,
            CGridDesc_GM0_GM1_GN0_GN1,
            GM1PerBlockGM11,
            GN1PerBlockGN11,
            GK0PerBlock,
            BM1PerThreadBM11,
            BN1PerThreadBN11,
            BK0PerThread,
            BM10BN10ThreadClusterBM100,
            BM10BN10ThreadClusterBN100,
            BM10BN10ThreadClusterBM101,
            BM10BN10ThreadClusterBN101,
            ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferSrcVectorTensorContiguousDimOrder,
            BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferSrcVectorTensorContiguousDimOrder,
            CThreadTransferSrcDstAccessOrder,
            CThreadTransferSrcDstVectorDim,
            CThreadTransferDstScalarPerVector,
            AGridIteratorHacks,
            BGridIteratorHacks,
            CGridIteratorHacks,
            AGridMoveSliceWindowIteratorHacks,
            BGridMoveSliceWindowIteratorHacks>;

    const auto GK0 = a_grid_desc_gk0_gm0_gm1_gk1.GetLength(I0);

    if(!GridwiseContraction::CheckValidity(
           a_grid_desc_gk0_gm0_gm1_gk1, b_grid_desc_gk0_gn0_gn1_gk1, c_grid_desc_gm0_gm1_gn0_gn1))
    {
        throw std::runtime_error("wrong! "
                                 "GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_"
                                 "GM0_GM1_GN0_GN1 has invalid setting");
    }

    const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
        GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
    const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
        GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);

    using AGridDesc_GK0_GM0_GM10_GM11_GK1 = decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1);
    using BGridDesc_GK0_GN0_GN10_GN11_GK1 = decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1);

    // c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1
    const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
        GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
            c_grid_desc_gm0_gm1_gn0_gn1);

    using CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1 = decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);

    // c_grid_block_cluster_blockid_to_gm10_gn10
    const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
        GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
            c_grid_desc_gm0_gm1_gn0_gn1);

    using CGridBlockCluster_BlockId_To_GM10_GN10 =
        decltype(c_grid_block_cluster_blockid_to_gm10_gn10);

    const index_t grid_size = GridwiseContraction::CalculateGridSize(c_grid_desc_gm0_gm1_gn0_gn1);

    const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);

    const bool has_double_tail_k_block_loop =
        GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);

    {
        std::cout << "a_grid_desc_gk0_gm0_gm10_gm11_gk1{"
                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0) << ", "
                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I1) << ", "
                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I2) << ", "
                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I3) << ", "
                  << a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I4) << "}" << std::endl;

        std::cout << "b_grid_desc_gk0_gn0_gn10_gn11_gk1{"
                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I0) << ", "
                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I1) << ", "
                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I2) << ", "
                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I3) << ", "
                  << b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetLength(I4) << "}" << std::endl;

        std::cout << "c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1{ "
                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I0) << ", "
                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I1) << ", "
                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I2) << ", "
                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I3) << ", "
                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I4) << ", "
                  << c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetLength(I5) << "}" << std::endl;
    }

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
            true,
            true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                                          c_grid_block_cluster_blockid_to_gm10_gn10);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
            true,
            false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                                          c_grid_block_cluster_blockid_to_gm10_gn10);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
            false,
            true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                                          c_grid_block_cluster_blockid_to_gm10_gn10);
    }
    else
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction,
            FloatAB,
            FloatC,
            remove_reference_t<AGridDesc_GK0_GM0_GM10_GM11_GK1>,
            remove_reference_t<BGridDesc_GK0_GN0_GN10_GN11_GK1>,
            remove_reference_t<CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1>,
            remove_reference_t<CGridBlockCluster_BlockId_To_GM10_GN10>,
            false,
            false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_grid_desc_gk0_gm0_gm10_gm11_gk1,
                                          b_grid_desc_gk0_gn0_gn10_gn11_gk1,
                                          c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
                                          c_grid_block_cluster_blockid_to_gm10_gn10);
    }

    return ave_time;
}

} // namespace ck
#endif
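A note on the branch ladder in the driver above: has_main_k_block_loop and has_double_tail_k_block_loop are runtime values, but the kernel template takes them as compile-time flags, so the driver must spell out all four instantiations. A minimal host-side sketch of that pattern, with hypothetical names (kernel_stub, dispatch) standing in for the real kernel and driver:

#include <cstdio>

// Stand-in for kernel_dynamic_contraction_v1r2: the two bool template
// parameters fix the K-loop structure at compile time.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void kernel_stub()
{
    std::printf("main=%d, double_tail=%d\n", HasMainKBlockLoop, HasDoubleTailKBlockLoop);
}

// Runtime booleans select one of the four instantiations.
void dispatch(bool has_main, bool has_double_tail)
{
    if(has_main && has_double_tail)
        kernel_stub<true, true>();
    else if(has_main && !has_double_tail)
        kernel_stub<true, false>();
    else if(!has_main && has_double_tail)
        kernel_stub<false, true>();
    else
        kernel_stub<false, false>();
}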
@@ -1,353 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v2.hpp"
#include "gridwise_operation_wrapper.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          index_t KPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t EPerBlock,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t EPerThread,
          typename ABlockTransferThreadSliceLengths_E_K,
          typename ABlockTransferThreadClusterLengths_E_K,
          index_t ABlockTransferSrcScalarPerVector_E,
          index_t ABlockTransferDstScalarPerVector_K,
          index_t BThreadTransferSrcScalarPerVector_W,
          index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_pad
{
    template <typename... Wei,
              typename... In,
              typename... Out,
              typename ConvStrides,
              typename ConvDilations,
              typename InLeftPads,
              typename InRightPads>
    __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
                      const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
                      const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
                      const ConvStrides& conv_strides,
                      const ConvDilations& conv_dilations,
                      const InLeftPads& in_left_pads,
                      const InRightPads& in_right_pads,
                      const FloatAB* __restrict__ p_wei_global,
                      const FloatAB* __restrict__ p_in_global,
                      FloatC* __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        const auto N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        const auto C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);

        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
        const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);

        const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);

        const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

        const auto ConvStrideH = conv_strides[I0];
        const auto ConvStrideW = conv_strides[I1];

        const auto ConvDilationH = conv_dilations[I0];
        const auto ConvDilationW = conv_dilations[I1];

        const auto InLeftPadH = in_left_pads[I0];
        const auto InLeftPadW = in_left_pads[I1];

        const auto InRightPadH = in_right_pads[I0];
        const auto InRightPadW = in_right_pads[I1];

        // weight tensor
        const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
            make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}));

        // input tensor
        const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pass_through_transform(C),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_pass_through_transform(C),
                make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Ho),
                       make_pass_through_transform(Wo)),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // output tensor
        const auto out_k_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Ho),
                       make_pass_through_transform(Wo)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto E = C * Y * X;

        if(!((K % KPerBlock) == 0 && (Ho % HoPerBlock) == 0 && (Wo % WoPerBlock) == 0 &&
             (E % EPerBlock) == 0))
        {
            throw std::runtime_error("wrong! GEMM size not divisible");
        }

        // hack to control index calculation when iterating over a_k_m_global tensor
        constexpr auto a_e_k_global_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

        constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};

        constexpr auto b_e_n_ho_wo_global_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};

        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
        // hack for NKHW format
        constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{}));

#if 1
        // GEMM
        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
            InMemoryDataOperation::Set,
            decltype(wei_e_k_global_desc),
            decltype(in_e_n_ho_wo_global_desc),
            decltype(out_k_n_ho_wo_global_desc),
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            EPerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E_K,
            ABlockTransferThreadClusterLengths_E_K,
            Sequence<1, 0>,
            Sequence<1, 0>,
            0,
            ABlockTransferSrcScalarPerVector_E,
            ABlockTransferDstScalarPerVector_K,
            false, // don't move back src coordinate after threadwise copy
            Sequence<0, 2, 3, 1>,
            3,
            BThreadTransferSrcScalarPerVector_W,
            false, // don't move back src coordinate after threadwise copy, which will be fused with
                   // MoveSrcSliceWindow() to save addr computation
            Sequence<0, 2, 3, 1>,
            0,
            CThreadTransferDstScalarPerVector_W,
            decltype(a_e_k_global_iterator_hacks),
            decltype(b_e_n_ho_wo_global_iterator_hacks),
            decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
            decltype(a_e_k_global_move_slice_window_iterator_hack),
            decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;

        const auto GridSize = (K / KPerBlock) * (Ho / HoPerBlock) * (Wo / WoPerBlock) * N;

        const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;

        const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;

        index_t nrepeat = 100;

        for(index_t i = 0; i < 5; ++i)
        {
            std::cout << "Start running " << nrepeat << " times..." << std::endl;

            KernelTimer timer;
            timer.Start();
            std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
                      << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
                      << std::endl;

            for(index_t j = 0; j < nrepeat; ++j)
            {
                if(has_main_k_block_loop && has_double_tail_k_block_loop)
                {
                    const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                               decltype(wei_e_k_global_desc),
                                                               const FloatAB*,
                                                               decltype(in_e_n_ho_wo_global_desc),
                                                               const FloatAB*,
                                                               decltype(out_k_n_ho_wo_global_desc),
                                                               FloatC*,
                                                               integral_constant<bool, true>,
                                                               integral_constant<bool, true>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_ho_wo_global_desc,
                                  p_out_global,
                                  integral_constant<bool, true>{},
                                  integral_constant<bool, true>{});
                }
                else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
                {
                    const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                               decltype(wei_e_k_global_desc),
                                                               const FloatAB*,
                                                               decltype(in_e_n_ho_wo_global_desc),
                                                               const FloatAB*,
                                                               decltype(out_k_n_ho_wo_global_desc),
                                                               FloatC*,
                                                               integral_constant<bool, true>,
                                                               integral_constant<bool, false>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_ho_wo_global_desc,
                                  p_out_global,
                                  integral_constant<bool, true>{},
                                  integral_constant<bool, false>{});
                }
                else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
                {
                    const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                               decltype(wei_e_k_global_desc),
                                                               const FloatAB*,
                                                               decltype(in_e_n_ho_wo_global_desc),
                                                               const FloatAB*,
                                                               decltype(out_k_n_ho_wo_global_desc),
                                                               FloatC*,
                                                               integral_constant<bool, false>,
                                                               integral_constant<bool, true>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_ho_wo_global_desc,
                                  p_out_global,
                                  integral_constant<bool, false>{},
                                  integral_constant<bool, true>{});
                }
                else
                {
                    const auto kernel = run_gridwise_operation<gridwise_gemm,
                                                               decltype(wei_e_k_global_desc),
                                                               const FloatAB*,
                                                               decltype(in_e_n_ho_wo_global_desc),
                                                               const FloatAB*,
                                                               decltype(out_k_n_ho_wo_global_desc),
                                                               FloatC*,
                                                               integral_constant<bool, false>,
                                                               integral_constant<bool, false>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_ho_wo_global_desc,
                                  p_out_global,
                                  integral_constant<bool, false>{},
                                  integral_constant<bool, false>{});
                }
            }

            timer.End();

            float ave_time = timer.GetElapsedTime() / nrepeat;

            float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
                                                            wei_k_c_y_x_global_desc,
                                                            out_n_k0_ho_wo_k1_global_desc) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
#endif
    }
};
} // namespace ck
#endif
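The descriptor transforms in Run above implement the standard implicit-GEMM view of forward convolution: the weight tensor is read as an E x K matrix with E = C * Y * X, the padded and embedded input as E x (N * Ho * Wo), and the output as K x (N * Ho * Wo). A small standalone sketch of that dimension mapping, assuming the usual convolution output-size formula; all names and sizes here are hypothetical:

#include <iostream>

int main()
{
    const int N = 1, C = 16, Hi = 28, Wi = 28; // input NCHW
    const int K = 64, Y = 3, X = 3;            // filter KCYX
    const int stride = 1, dilation = 1, pad_l = 1, pad_r = 1;

    // standard convolution output size
    const int Ho = (Hi + pad_l + pad_r - dilation * (Y - 1) - 1) / stride + 1;
    const int Wo = (Wi + pad_l + pad_r - dilation * (X - 1) - 1) / stride + 1;

    // the merged reduction dimension of the implicit GEMM
    const int E = C * Y * X;

    std::cout << "GEMM A (weight): E x K = " << E << " x " << K << "\n";
    std::cout << "GEMM B (input) : E x (N*Ho*Wo) = " << E << " x " << N * Ho * Wo << "\n";
    std::cout << "GEMM C (output): K x (N*Ho*Wo) = " << K << " x " << N * Ho * Wo << "\n";
}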
@@ -1,368 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP
#define CK_DRIVER_DYNAMIC_CONVOLUTION_FORWARD_IMPLICIT_GEMM_V5R1_NCHW_KCYX_NKHW_OUTPAD_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v2.hpp"
#include "gridwise_operation_wrapper.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          index_t KPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t EPerBlock,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t EPerThread,
          typename ABlockTransferThreadSliceLengths_E_K,
          typename ABlockTransferThreadClusterLengths_E_K,
          index_t ABlockTransferSrcScalarPerVector_E,
          index_t ABlockTransferDstScalarPerVector_K,
          index_t BThreadTransferSrcScalarPerVector_W,
          index_t CThreadTransferDstScalarPerVector_W>
struct DriverDynamicConvolutionForwardImplicitGemm_v5r1_nchw_kcyx_nkhw_outpad
{
    template <typename... Wei,
              typename... In,
              typename... Out,
              typename ConvStrides,
              typename ConvDilations,
              typename InLeftPads,
              typename InRightPads>
    __host__ void Run(const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_global_desc,
                      const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_global_desc,
                      const DynamicTensorDescriptor<Out...>& out_n_k0_ho_wo_k1_global_desc,
                      const ConvStrides& conv_strides,
                      const ConvDilations& conv_dilations,
                      const InLeftPads& in_left_pads,
                      const InRightPads& in_right_pads,
                      const FloatAB* __restrict__ p_wei_global,
                      const FloatAB* __restrict__ p_in_global,
                      FloatC* __restrict__ p_out_global) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};
        constexpr auto I4 = Number<4>{};

        const auto N  = in_n_c_hi_wi_global_desc.GetLength(I0);
        const auto C  = in_n_c_hi_wi_global_desc.GetLength(I1);
        const auto K0 = out_n_k0_ho_wo_k1_global_desc.GetLength(I1);

        const auto Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
        const auto Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

        const auto Ho = out_n_k0_ho_wo_k1_global_desc.GetLength(I2);
        const auto Wo = out_n_k0_ho_wo_k1_global_desc.GetLength(I3);

        const auto K1 = out_n_k0_ho_wo_k1_global_desc.GetLength(I4);

        const auto K = wei_k_c_y_x_global_desc.GetLength(I0);
        const auto Y = wei_k_c_y_x_global_desc.GetLength(I2);
        const auto X = wei_k_c_y_x_global_desc.GetLength(I3);

        const auto ConvStrideH = conv_strides[I0];
        const auto ConvStrideW = conv_strides[I1];

        const auto ConvDilationH = conv_dilations[I0];
        const auto ConvDilationW = conv_dilations[I1];

        const auto Hop = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock;
        const auto Wop = (Wo + WoPerBlock - 1) / WoPerBlock * WoPerBlock;

        const auto OutRightPadH = Hop - Ho;
        const auto OutRightPadW = Wop - Wo;

        const auto InLeftPadH = in_left_pads[I0];
        const auto InLeftPadW = in_left_pads[I1];

        const auto InRightPadH = in_right_pads[I0] + OutRightPadH * ConvStrideH;
        const auto InRightPadW = in_right_pads[I1] + OutRightPadW * ConvStrideW;

        std::cerr << "OutRightPadH = " << OutRightPadH << " OutRightPadW = " << OutRightPadW
                  << std::endl;
        std::cerr << "InRightPadH = " << InRightPadH << " InRightPadW = " << InRightPadW
                  << std::endl;

        // weight tensor
        const auto wei_e_k_global_desc = transform_dynamic_tensor_descriptor(
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
            make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<1>{}, Sequence<0>{}));

        // input tensor
        const auto in_n_c_hip_wip_global_desc = transform_dynamic_tensor_descriptor(
            in_n_c_hi_wi_global_desc,
            make_tuple(make_pass_through_transform(N),
                       make_pass_through_transform(C),
                       make_pad_transform(Hi, InLeftPadH, InRightPadH),
                       make_pad_transform(Wi, InLeftPadW, InRightPadW)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto in_n_c_y_ho_x_wo_global_desc = transform_dynamic_tensor_descriptor(
            in_n_c_hip_wip_global_desc,
            make_tuple(
                make_pass_through_transform(N),
                make_pass_through_transform(C),
                make_embed_transform(make_tuple(Y, Hop), make_tuple(ConvDilationH, ConvStrideH)),
                make_embed_transform(make_tuple(X, Wop), make_tuple(ConvDilationW, ConvStrideW))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

        const auto in_e_n_ho_wo_global_desc = transform_dynamic_tensor_descriptor(
            in_n_c_y_ho_x_wo_global_desc,
            make_tuple(make_merge_transform(make_tuple(C, Y, X)),
                       make_pass_through_transform(N),
                       make_pass_through_transform(Hop),
                       make_pass_through_transform(Wop)),
            make_tuple(Sequence<1, 2, 4>{}, Sequence<0>{}, Sequence<3>{}, Sequence<5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // output tensor
        const auto out_k_n_hop_wop_global_desc = transform_dynamic_tensor_descriptor(
            make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K0, Ho, Wo, K1)),
            make_tuple(make_merge_transform(make_tuple(K0, K1)),
                       make_pass_through_transform(N),
                       make_pad_transform(Ho, 0, OutRightPadH),
                       make_pad_transform(Wo, 0, OutRightPadW)),
            make_tuple(Sequence<1, 4>{}, Sequence<0>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto E = C * Y * X;

        std::cerr << "Hop = " << Hop << " Wop = " << Wop << std::endl;

        if(!((K % KPerBlock) == 0 && (Hop % HoPerBlock) == 0 && (Wop % WoPerBlock) == 0 &&
             (E % EPerBlock) == 0))
        {
            throw std::runtime_error("wrong! GEMM size not divisible");
        }

        // hack to control index calculation when iterating over a_k_m_global tensor
        constexpr auto a_e_k_global_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0>{}, Sequence<0, 0, 0>{}));

        constexpr auto a_e_k_global_move_slice_window_iterator_hack = Sequence<0, 0, 0>{};

        constexpr auto b_e_n_ho_wo_global_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}));

        constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{};

        // hack to control index calculation when iterating over c_m0_m1_n0_n1_global tensor
        // hack for NKHW format
        constexpr auto c_k_n_ho_wo_global_tensor_iterator_hacks =
            make_tuple(make_tuple(Sequence<0, 1, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{}),
                       make_tuple(Sequence<0, 2, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{},
                                  Sequence<0, 0, 0, 0, 0>{}));

        // GEMM
        using gridwise_gemm = GridwiseDynamicGemm_km_kn_mn_v3<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
            InMemoryDataOperation::Set,
            decltype(wei_e_k_global_desc),
            decltype(in_e_n_ho_wo_global_desc),
            decltype(out_k_n_hop_wop_global_desc),
            KPerBlock,
            HoPerBlock,
            WoPerBlock,
            EPerBlock,
            KPerThread,
            HoPerThread,
            WoPerThread,
            EPerThread,
            ABlockTransferThreadSliceLengths_E_K,
            ABlockTransferThreadClusterLengths_E_K,
            Sequence<1, 0>,
            Sequence<1, 0>,
            0,
            ABlockTransferSrcScalarPerVector_E,
            ABlockTransferDstScalarPerVector_K,
            false, // don't move back src coordinate after threadwise copy
            Sequence<0, 2, 3, 1>,
            3,
            BThreadTransferSrcScalarPerVector_W,
            false, // don't move back src coordinate after threadwise copy, which will be fused with
                   // MoveSrcSliceWindow() to save addr computation
            Sequence<0, 2, 3, 1>,
            0,
            CThreadTransferDstScalarPerVector_W,
            decltype(a_e_k_global_iterator_hacks),
            decltype(b_e_n_ho_wo_global_iterator_hacks),
            decltype(c_k_n_ho_wo_global_tensor_iterator_hacks),
            decltype(a_e_k_global_move_slice_window_iterator_hack),
            decltype(b_e_n_ho_wo_global_move_slice_window_iterator_hack)>;

        const auto GridSize = (K / KPerBlock) * (Hop / HoPerBlock) * (Wop / WoPerBlock) * N;

        const bool has_main_k_block_loop = (E + EPerBlock) / (2 * EPerBlock) > 1;

        const bool has_double_tail_k_block_loop = (E / EPerBlock) % 2 == 0;

        index_t nrepeat = 100;

        for(index_t i = 0; i < 5; ++i)
        {
            std::cout << "Start running " << nrepeat << " times..." << std::endl;

            KernelTimer timer;
            timer.Start();
            std::cout << "has_main_k_block_loop: " << has_main_k_block_loop
                      << " has_double_tail_k_block_loop: " << has_double_tail_k_block_loop
                      << std::endl;

            for(index_t j = 0; j < nrepeat; ++j)
            {
                if(has_main_k_block_loop && has_double_tail_k_block_loop)
                {
                    const auto kernel =
                        run_gridwise_operation<gridwise_gemm,
                                               decltype(wei_e_k_global_desc),
                                               const FloatAB*,
                                               decltype(in_e_n_ho_wo_global_desc),
                                               const FloatAB*,
                                               decltype(out_k_n_hop_wop_global_desc),
                                               FloatC*,
                                               integral_constant<bool, true>,
                                               integral_constant<bool, true>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_hop_wop_global_desc,
                                  p_out_global,
                                  integral_constant<bool, true>{},
                                  integral_constant<bool, true>{});
                }
                else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
                {
                    const auto kernel =
                        run_gridwise_operation<gridwise_gemm,
                                               decltype(wei_e_k_global_desc),
                                               const FloatAB*,
                                               decltype(in_e_n_ho_wo_global_desc),
                                               const FloatAB*,
                                               decltype(out_k_n_hop_wop_global_desc),
                                               FloatC*,
                                               integral_constant<bool, true>,
                                               integral_constant<bool, false>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_hop_wop_global_desc,
                                  p_out_global,
                                  integral_constant<bool, true>{},
                                  integral_constant<bool, false>{});
                }
                else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
                {
                    const auto kernel =
                        run_gridwise_operation<gridwise_gemm,
                                               decltype(wei_e_k_global_desc),
                                               const FloatAB*,
                                               decltype(in_e_n_ho_wo_global_desc),
                                               const FloatAB*,
                                               decltype(out_k_n_hop_wop_global_desc),
                                               FloatC*,
                                               integral_constant<bool, false>,
                                               integral_constant<bool, true>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_hop_wop_global_desc,
                                  p_out_global,
                                  integral_constant<bool, false>{},
                                  integral_constant<bool, true>{});
                }
                else
                {
                    const auto kernel =
                        run_gridwise_operation<gridwise_gemm,
                                               decltype(wei_e_k_global_desc),
                                               const FloatAB*,
                                               decltype(in_e_n_ho_wo_global_desc),
                                               const FloatAB*,
                                               decltype(out_k_n_hop_wop_global_desc),
                                               FloatC*,
                                               integral_constant<bool, false>,
                                               integral_constant<bool, false>>;

                    launch_kernel(kernel,
                                  dim3(GridSize),
                                  dim3(BlockSize),
                                  0,
                                  0,
                                  wei_e_k_global_desc,
                                  p_wei_global,
                                  in_e_n_ho_wo_global_desc,
                                  p_in_global,
                                  out_k_n_hop_wop_global_desc,
                                  p_out_global,
                                  integral_constant<bool, false>{},
                                  integral_constant<bool, false>{});
                }
            }

            timer.End();

            float ave_time = timer.GetElapsedTime() / nrepeat;

            float perf = (float)calculate_convolution_flops(in_n_c_hi_wi_global_desc,
                                                            wei_k_c_y_x_global_desc,
                                                            out_n_k0_ho_wo_k1_global_desc) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }
};
} // namespace ck
#endif
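The only substantive difference from the _pad driver above is the output right-padding: Hop and Wop round Ho and Wo up to the next multiple of the block tile so the divisibility check can always pass, and a matching amount of extra input right-padding backs the padded output positions. A tiny standalone sketch of that arithmetic, with assumed example sizes:

#include <cassert>

int main()
{
    const int Ho = 30, HoPerBlock = 8, ConvStrideH = 2;

    // round Ho up to a multiple of HoPerBlock
    const int Hop          = (Ho + HoPerBlock - 1) / HoPerBlock * HoPerBlock; // 32
    const int OutRightPadH = Hop - Ho;                                        // 2

    // extra input right-padding needed to back the padded output rows
    const int InExtraPadH = OutRightPadH * ConvStrideH;                       // 4

    assert(Hop % HoPerBlock == 0);
    return InExtraPadH == 4 ? 0 : 1;
}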
@@ -1,416 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R2
#define CK_DRIVER_DYNAMIC_GEMM_V1R2

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r2.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AKMGridDesc,
          typename BKNGridDesc,
          typename CMNGridDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K_M0_M1,
          typename ABlockTransferThreadClusterLengths_K_M0_M1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_M1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K_N0_N1,
          typename BBlockTransferThreadClusterLengths_K_N0_N1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_N1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_v1r2(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
                                        const AKMGridDesc& a_k_m_grid_desc,
                                        const BKNGridDesc& b_k_n_grid_desc,
                                        const CMNGridDesc& c_m_n_grid_desc,
                                        AGridIteratorHacks,
                                        BGridIteratorHacks,
                                        CGridIteratorHacks,
                                        AGridMoveSliceWindowIteratorHacks,
                                        BGridMoveSliceWindowIteratorHacks,
                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseGemm =
        GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
                                          FloatAB,
                                          FloatAcc,
                                          FloatC,
                                          CGlobalMemoryDataOperation,
                                          AKMGridDesc,
                                          BKNGridDesc,
                                          CMNGridDesc,
                                          MPerBlock,
                                          NPerBlock,
                                          KPerBlock,
                                          M1PerThread,
                                          N1PerThread,
                                          KPerThread,
                                          M1N1ThreadClusterM10,
                                          M1N1ThreadClusterN10,
                                          M1N1ThreadClusterM11,
                                          M1N1ThreadClusterN11,
                                          ABlockTransferThreadSliceLengths_K_M0_M1,
                                          ABlockTransferThreadClusterLengths_K_M0_M1,
                                          ABlockTransferThreadClusterArrangeOrder,
                                          ABlockTransferSrcAccessOrder,
                                          ABlockTransferSrcVectorDim,
                                          ABlockTransferSrcScalarPerVector,
                                          ABlockTransferDstScalarPerVector_M1,
                                          AThreadTransferSrcResetCoordinateAfterRun,
                                          BBlockTransferThreadSliceLengths_K_N0_N1,
                                          BBlockTransferThreadClusterLengths_K_N0_N1,
                                          BBlockTransferThreadClusterArrangeOrder,
                                          BBlockTransferSrcAccessOrder,
                                          BBlockTransferSrcVectorDim,
                                          BBlockTransferSrcScalarPerVector,
                                          BBlockTransferDstScalarPerVector_N1,
                                          BThreadTransferSrcResetCoordinateAfterRun,
                                          CThreadTransferSrcDstAccessOrder,
                                          CThreadTransferSrcDstVectorDim,
                                          CThreadTransferDstScalarPerVector,
                                          AGridIteratorHacks,
                                          BGridIteratorHacks,
                                          CGridIteratorHacks,
                                          AGridMoveSliceWindowIteratorHacks,
                                          BGridMoveSliceWindowIteratorHacks>;

    const auto M = a_k_m_grid_desc.GetLength(I1);
    const auto N = b_k_n_grid_desc.GetLength(I1);
    const auto K = a_k_m_grid_desc.GetLength(I0);

    if(!GridwiseGemm::CheckValidity(a_k_m_grid_desc, b_k_n_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r2 has invalid setting");
    }

    const auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    const auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);

    using AKM0M1GridDesc = decltype(a_k_m0_m1_grid_desc);
    using BKN0N1GridDesc = decltype(b_k_n0_n1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // c_blockid_to_m0_n0_block_cluster_adaptor
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

    const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

    const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K);

    {
        std::cout << "a_k_m0_m1_grid_desc{" << a_k_m0_m1_grid_desc.GetLength(I0) << ", "
                  << a_k_m0_m1_grid_desc.GetLength(I1) << ", " << a_k_m0_m1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "b_k_n0_n1_grid_desc{" << b_k_n0_n1_grid_desc.GetLength(I0) << ", "
                  << b_k_n0_n1_grid_desc.GetLength(I1) << ", " << b_k_n0_n1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
    }

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k_m0_m1_grid_desc,
                                          b_k_n0_n1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k_m0_m1_grid_desc_dev_buf(sizeof(AKM0M1GridDesc));
    DeviceMem b_k_n0_n1_grid_desc_dev_buf(sizeof(BKN0N1GridDesc));
    DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
    DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
        sizeof(CBlockIdToM0N0BlockClusterAdaptor));

    a_k_m0_m1_grid_desc_dev_buf.ToDevice(&a_k_m0_m1_grid_desc);
    b_k_n0_n1_grid_desc_dev_buf.ToDevice(&b_k_n0_n1_grid_desc);
    c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
    c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
        &c_blockid_to_m0_n0_block_cluster_adaptor);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r2<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AKM0M1GridDesc>,
                                     remove_reference_t<BKN0N1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k_m0_m1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k_n0_n1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
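In the BY_VOID_POINTER branch above, each descriptor is copied into a device buffer once on the host (ToDevice), and the kernel receives only a type-erased pointer cast to the CONSTANT (address space 4) qualifier mentioned in the commit message; the kernel casts it back to the concrete descriptor type. A plain-C++ sketch of just the type-erasure part, with hypothetical stand-ins (Desc, kernel_stub, malloc/memcpy) for the device buffer and kernel:

#include <cstdlib>
#include <cstring>
#include <iostream>

// Stand-in POD for a tensor descriptor; the real descriptors are
// trivially-copyable template types.
struct Desc
{
    int lengths[3];
};

// Stand-in for the kernel: receives a type-erased pointer and casts it back.
void kernel_stub(const void* p_desc_void)
{
    const Desc* p_desc = static_cast<const Desc*>(p_desc_void);
    std::cout << p_desc->lengths[0] << std::endl;
}

int main()
{
    Desc desc{{4, 8, 16}};

    void* dev_buf = std::malloc(sizeof(Desc)); // stands in for DeviceMem
    std::memcpy(dev_buf, &desc, sizeof(Desc)); // stands in for ToDevice()
    kernel_stub(dev_buf);                      // pointer as from GetDeviceBuffer()
    std::free(dev_buf);
}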
@@ -1,416 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_v1r3
#define CK_DRIVER_DYNAMIC_GEMM_v1r3

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r3.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
                                        const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                        const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                        const CMNGridDesc& c_m_n_grid_desc,
                                        AGridIteratorHacks,
                                        BGridIteratorHacks,
                                        CGridIteratorHacks,
                                        AGridMoveSliceWindowIteratorHacks,
                                        BGridMoveSliceWindowIteratorHacks,
                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseGemm =
        GridwiseDynamicGemm_km_kn_mn_v1r3<BlockSize,
                                          FloatAB,
                                          FloatAcc,
                                          FloatC,
                                          CGlobalMemoryDataOperation,
                                          AK0MK1GridDesc,
                                          BK0NK1GridDesc,
                                          CMNGridDesc,
                                          MPerBlock,
                                          NPerBlock,
                                          KPerBlock,
                                          M1PerThread,
                                          N1PerThread,
                                          KPerThread,
                                          M1N1ThreadClusterM10,
                                          M1N1ThreadClusterN10,
                                          M1N1ThreadClusterM11,
                                          M1N1ThreadClusterN11,
                                          ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
                                          ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
                                          ABlockTransferThreadClusterArrangeOrder,
                                          ABlockTransferSrcAccessOrder,
                                          ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
                                          ABlockTransferSrcVectorTensorContiguousDimOrder,
                                          ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
                                          BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
                                          BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
                                          BBlockTransferThreadClusterArrangeOrder,
                                          BBlockTransferSrcAccessOrder,
                                          BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
                                          BBlockTransferSrcVectorTensorContiguousDimOrder,
                                          BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
                                          CThreadTransferSrcDstAccessOrder,
                                          CThreadTransferSrcDstVectorDim,
                                          CThreadTransferDstScalarPerVector,
                                          AGridIteratorHacks,
                                          BGridIteratorHacks,
                                          CGridIteratorHacks,
                                          AGridMoveSliceWindowIteratorHacks,
                                          BGridMoveSliceWindowIteratorHacks>;

    const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
    const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
    const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);

    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r3 has invalid setting");
    }

    const auto a_k0_m0_m1_k1_grid_desc =
        GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
    const auto b_k0_n0_n1_k1_grid_desc =
        GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);

    using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
    using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // c_blockid_to_m0_n0_block_cluster_adaptor
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

    const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);

    const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
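
    // Note: the two flags above select among four compile-time specializations of
    // kernel_dynamic_gemm_v1r3; the main-loop and double-tail-loop variants are
    // baked in as the kernel's last two bool template arguments. Sketch of the
    // dispatch implemented by the branches below:
    //
    //   (has_main, has_double_tail) == (true,  true ) -> kernel<..., true,  true >
    //   (has_main, has_double_tail) == (true,  false) -> kernel<..., true,  false>
    //   (has_main, has_double_tail) == (false, true ) -> kernel<..., false, true >
    //   (has_main, has_double_tail) == (false, false) -> kernel<..., false, false>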

    {
        std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
    }

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     true>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     false>;

        ave_time = launch_and_time_kernel(kernel,
                                          nrepeat,
                                          dim3(grid_size),
                                          dim3(BlockSize),
                                          0,
                                          0,
                                          p_a_grid,
                                          p_b_grid,
                                          p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
    DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
    DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
    DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
        sizeof(CBlockIdToM0N0BlockClusterAdaptor));

    a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
    b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
    c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
    c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
        &c_blockid_to_m0_n0_block_cluster_adaptor);
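
    // Note: in this mode each descriptor object is staged in its own device
    // buffer and the kernel receives it as an opaque pointer in the CONSTANT
    // address space (address space 4 on AMD GPUs), to be cast back to the
    // descriptor type on the device side. A minimal sketch of the receiving end;
    // the exact kernel signature and any address-space-cast helper are
    // assumptions, not taken from this diff:
    //
    //   __global__ void kernel(..., const void __CONSTANT__* p_a_desc, ...)
    //   {
    //       // cast the address-space-4 pointer back to the descriptor type
    //       const auto a_desc =
    //           *reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_desc);
    //       ...
    //   }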

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     true,
                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     true>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel =
            kernel_dynamic_gemm_v1r3<GridwiseGemm,
                                     FloatAB,
                                     FloatC,
                                     remove_reference_t<AK0M0M1K1GridDesc>,
                                     remove_reference_t<BK0N0N1K1GridDesc>,
                                     remove_reference_t<CM0M10M11N0N10N11GridDesc>,
                                     remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
                                     false,
                                     false>;

        ave_time = launch_and_time_kernel(
            kernel,
            nrepeat,
            dim3(grid_size),
            dim3(BlockSize),
            0,
            0,
            p_a_grid,
            p_b_grid,
            p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
@@ -1,198 +0,0 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3
#define CK_DRIVER_DYNAMIC_GEMM_XDLOPS_V2R3

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t K1,
          index_t MRepeat,
          index_t NRepeat,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          bool AThreadTransferSrcResetCoordinateAfterRun,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          bool BThreadTransferSrcResetCoordinateAfterRun,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks,
          bool CAccessOrderMRepeatNRepeat>
__host__ float driver_dynamic_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
                                               const FloatAB* p_b_grid,
                                               FloatC* p_c_grid,
                                               const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                               const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                               const CMNGridDesc& c_m_n_grid_desc,
                                               AGridIteratorHacks,
                                               BGridIteratorHacks,
                                               CGridIteratorHacks,
                                               AGridMoveSliceWindowIteratorHacks,
                                               BGridMoveSliceWindowIteratorHacks,
                                               index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    using GridwiseGemm =
        GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3<BlockSize,
                                                       FloatAB,
                                                       FloatAcc,
                                                       FloatC,
                                                       CGlobalMemoryDataOperation,
                                                       AK0MK1GridDesc,
                                                       BK0NK1GridDesc,
                                                       CMNGridDesc,
                                                       MPerBlock,
                                                       NPerBlock,
                                                       KPerBlock,
                                                       MPerWave,
                                                       NPerWave,
                                                       K1,
                                                       MRepeat,
                                                       NRepeat,
                                                       ABlockTransferThreadSliceLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterLengths_K0_M_K1,
                                                       ABlockTransferThreadClusterArrangeOrder,
                                                       ABlockTransferSrcAccessOrder,
                                                       ABlockTransferSrcVectorDim,
                                                       ABlockTransferSrcScalarPerVector,
                                                       ABlockTransferDstScalarPerVector_K1,
                                                       AThreadTransferSrcResetCoordinateAfterRun,
                                                       BBlockTransferThreadSliceLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterLengths_K0_N_K1,
                                                       BBlockTransferThreadClusterArrangeOrder,
                                                       BBlockTransferSrcAccessOrder,
                                                       BBlockTransferSrcVectorDim,
                                                       BBlockTransferSrcScalarPerVector,
                                                       BBlockTransferDstScalarPerVector_K1,
                                                       BThreadTransferSrcResetCoordinateAfterRun,
                                                       CThreadTransferSrcDstAccessOrder,
                                                       CThreadTransferSrcDstVectorDim,
                                                       CThreadTransferDstScalarPerVector,
                                                       AGridIteratorHacks,
                                                       BGridIteratorHacks,
                                                       CGridIteratorHacks,
                                                       AGridMoveSliceWindowIteratorHacks,
                                                       BGridMoveSliceWindowIteratorHacks,
                                                       CAccessOrderMRepeatNRepeat>;
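
    // Note (an assumption from the parameter names, not stated in this diff):
    // each MPerBlock x NPerBlock workgroup tile is presumably covered by waves
    // that each compute an MPerWave x NPerWave xdlops tile, repeated
    // MRepeat x NRepeat times, with K1 the vectorized innermost contraction
    // length; e.g. hypothetical sizes MPerBlock = 128, MPerWave = 64, MRepeat = 2
    // would leave a single wave span along M.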

    {
        std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", "
                  << a_k0_m_k1_grid_desc.GetLength(I1) << ", " << a_k0_m_k1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "b_k0_n_k1_grid_desc{" << b_k0_n_k1_grid_desc.GetLength(I0) << ", "
                  << b_k0_n_k1_grid_desc.GetLength(I1) << ", " << b_k0_n_k1_grid_desc.GetLength(I2)
                  << "}" << std::endl;

        std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", "
                  << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl;
    }

    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error(
            "wrong! GridwiseDynamicGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");
    }

    const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc);

    using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc);

    const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc);

    const auto kernel = kernel_dynamic_gemm_xdlops_v2r3<GridwiseGemm,
                                                        FloatAB,
                                                        FloatC,
                                                        remove_reference_t<AK0MK1GridDesc>,
                                                        remove_reference_t<BK0NK1GridDesc>,
                                                        remove_reference_t<CM0M1M2NGridDesc>,
                                                        remove_reference_t<CBlockClusterAdaptor>>;

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = launch_and_time_kernel(kernel,
                                            nrepeat,
                                            dim3(grid_size),
                                            dim3(BlockSize),
                                            0,
                                            0,
                                            p_a_grid,
                                            p_b_grid,
                                            p_c_grid,
                                            a_k0_m_k1_grid_desc,
                                            b_k0_n_k1_grid_desc,
                                            c_m0_m1_m2_n_grid_desc,
                                            c_block_cluster_adaptor);

#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
    DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
    DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc));
    DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));

    a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc);
    b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc);
    c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc);
    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

    float ave_time = launch_and_time_kernel(
        kernel,
        nrepeat,
        dim3(grid_size),
        dim3(BlockSize),
        0,
        0,
        p_a_grid,
        p_b_grid,
        p_c_grid,
        (void __CONSTANT__*)a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer(),
        (void __CONSTANT__*)b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer(),
        (void __CONSTANT__*)c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer(),
        (void __CONSTANT__*)c_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
#endif
    return ave_time;
}

} // namespace ck
#endif

@@ -7,9 +7,12 @@

 namespace ck {

-// GemmM = K
-// GemmN = N * Ho * Wo
-// GemmK = C * Y * X
+// GemmM0 = 1
+// GemmM1 = K
+// GemmN0 = N0
+// GemmN1 = (N / N0) * Ho * Wo
+// GemmK0 = (C / C0) * Y * X
+// GemmK1 = C0
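// Worked example with hypothetical sizes (not from this diff): N = 128, C = 64,
// K = 256, Y = X = 3, N0 = 4 and C0 = 8 give GemmM1 = 256, GemmN0 = 4,
// GemmN1 = (128 / 4) * Ho * Wo = 32 * Ho * Wo, GemmK0 = (64 / 8) * 3 * 3 = 72,
// and GemmK1 = 8.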
template <typename... Wei,
          typename... In,
          typename... Out,

@@ -46,7 +46,7 @@ struct DynamicPassThrough
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
-                                                    const UpIdx& idx_up_new,
+                                                    const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&

@@ -136,7 +136,7 @@ struct DynamicPad
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
-                                                    const UpIdx& idx_up_new,
+                                                    const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&

@@ -227,7 +227,7 @@ struct DynamicLeftPad
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
-                                                    const UpIdx& idx_up_new,
+                                                    const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&

@@ -318,7 +318,7 @@ struct DynamicRightPad
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
-                                                    const UpIdx& idx_up_new,
+                                                    const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&

@@ -420,7 +420,7 @@ struct DynamicEmbed
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
+                                              const UpIdx&,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == NDimUp &&

@@ -1096,7 +1096,7 @@ struct DynamicMerge_v2_magic_division
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
-                                              const UpIdxDiff& idx_diff_up,
+                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const

@@ -1254,7 +1254,7 @@ struct DynamicMerge_v2r2_magic_division
              typename UpIdx,
              index_t Hack>
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
-                                              const UpIdxDiff& idx_diff_up,
+                                              const UpIdxDiff&,
                                              LowIdx& idx_low,
                                              const UpIdx& idx_up_new,
                                              Number<Hack>) const

@@ -1383,7 +1383,7 @@ struct DynamicUnMerge
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
+                                              const UpIdx&,
                                              Number<Hack>) const
    {
        CalculateLowerIndex(idx_diff_low, idx_diff_up);

@@ -1597,7 +1597,7 @@ struct DynamicVectorize
    __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                              const UpIdxDiff& idx_diff_up,
                                              LowIdx& idx_low,
-                                              const UpIdx& idx_up_new,
+                                              const UpIdx&,
                                              Number<Hack>) const
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&

@@ -1654,7 +1654,7 @@ struct DynamicSlice

    __host__ __device__ constexpr DynamicSlice() = default;

-    __host__ __device__ constexpr DynamicSlice(const LowLength& low_length,
+    __host__ __device__ constexpr DynamicSlice(const LowLength&,
                                               const SliceBegin& slice_begin,
                                               const SliceEnd& slice_end)
        : up_lengths_{make_tuple(slice_end - slice_begin)},

@@ -1687,7 +1687,7 @@ struct DynamicSlice
    __host__ __device__ static void UpdateLowerIndex(LowIdxDiff& idx_diff_low,
                                                     const UpIdxDiff& idx_diff_up,
                                                     LowIdx& idx_low,
-                                                    const UpIdx& idx_up_new,
+                                                    const UpIdx&,
                                                     Number<Hack>)
    {
        static_assert(LowIdxDiff::Size() == 1 && UpIdxDiff::Size() == 1 && LowIdx::Size() == 1 &&

@@ -1709,8 +1709,7 @@ struct DynamicSlice
    }

    template <typename UpIdx>
-    __host__ __device__ constexpr bool
-    IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& idx_up) const
+    __host__ __device__ constexpr bool IsValidUpperIndexMappedToValidLowerIndex(const UpIdx&) const
    {
        return true;
    }

@@ -317,7 +317,7 @@ transform_dynamic_tensor_descriptor(const OldTensorDescriptor& old_tensor_desc,
                                        NewUpperDimensionNewVisibleIdss{});

    static_assert(is_valid_sequence_map<decltype(all_old_top_ids)>::value &&
-                      is_valid_sequence_map<decltype(all_old_top_ids)>::value,
+                      is_valid_sequence_map<decltype(all_new_top_ids)>::value,
                  "wrong!");
}

@@ -395,7 +395,6 @@ __host__ __device__ constexpr auto make_dynamic_tensor_coordinate(const TensorDe

    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
-    constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
    constexpr auto visible_dim_ids = TensorDesc::GetVisibleDimensionIds();

    MultiIndex<ndim_hidden> idx_hidden;

@@ -491,11 +490,8 @@ template <typename TensorDesc, typename TensorCoord, typename TensorCoordIterato
__host__ __device__ constexpr void move_dynamic_tensor_coordinate(
    const TensorDesc& tensor_desc, TensorCoord& coord, const TensorCoordIterator& coord_iterator)
{
-    constexpr index_t ndim_hidden  = TensorDesc::GetNumOfHiddenDimension();
-    constexpr index_t ndim_visible = TensorDesc::GetNumOfVisibleDimension();
-    constexpr index_t ntransform   = TensorDesc::GetNumOfTransform();
-
-    using HiddenIndex = MultiIndex<ndim_hidden>;
+    constexpr index_t ndim_hidden = TensorDesc::GetNumOfHiddenDimension();
+    constexpr index_t ntransform  = TensorDesc::GetNumOfTransform();

    // this is what needs to be calculated
    auto idx_diff_hidden = make_zero_multi_index<ndim_hidden>();

@@ -236,15 +236,15 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a

    // shift
    constexpr index_t adaptor0_max_hidden_id = [&]() {
-        index_t adaptor0_max_hidden_id = NumericLimits<index_t>::Min();
+        index_t adaptor0_max_hidden_id_ = NumericLimits<index_t>::Min();

        static_for<0, TensorAdaptor0::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfLowerDimension();

            static_for<0, ndim_low, 1>{}([&](auto idim_low) {
-                adaptor0_max_hidden_id =
-                    math::max(adaptor0_max_hidden_id,
+                adaptor0_max_hidden_id_ =
+                    math::max(adaptor0_max_hidden_id_,
                              TensorAdaptor0::GetLowerDimensionHiddenIdss()[itran][idim_low].value);
            });

@@ -252,17 +252,17 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
                TensorAdaptor0{}.GetTransforms()[itran].GetNumOfUpperDimension();

            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
-                adaptor0_max_hidden_id =
-                    math::max(adaptor0_max_hidden_id,
+                adaptor0_max_hidden_id_ =
+                    math::max(adaptor0_max_hidden_id_,
                              TensorAdaptor0::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

-        return adaptor0_max_hidden_id;
+        return adaptor0_max_hidden_id_;
    }();

    constexpr index_t adaptor1_min_hidden_id = [&]() {
-        index_t adaptor1_min_hidden_id = NumericLimits<index_t>::Max();
+        index_t adaptor1_min_hidden_id_ = NumericLimits<index_t>::Max();

        static_for<0, TensorAdaptor1::GetNumOfTransform(), 1>{}([&](auto itran) {
            constexpr index_t ndim_low =

@@ -285,7 +285,7 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a

            if(!is_bottom_dim)
            {
-                adaptor1_min_hidden_id = math::min(adaptor1_min_hidden_id, low_dim_hidden_id);
+                adaptor1_min_hidden_id_ = math::min(adaptor1_min_hidden_id_, low_dim_hidden_id);
            }
        });

@@ -294,13 +294,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a

            // get the min of all upper dimensions
            static_for<0, ndim_up, 1>{}([&](auto idim_up) {
-                adaptor1_min_hidden_id =
-                    math::min(adaptor1_min_hidden_id,
+                adaptor1_min_hidden_id_ =
+                    math::min(adaptor1_min_hidden_id_,
                              TensorAdaptor1::GetUpperDimensionHiddenIdss()[itran][idim_up].value);
            });
        });

-        return adaptor1_min_hidden_id;
+        return adaptor1_min_hidden_id_;
    }();

    constexpr index_t adaptor1_hidden_id_shift =

@@ -321,11 +321,11 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
    // sequence in, sequence out
    constexpr auto low_dim_hidden_ids_1_mod = [&]() constexpr
    {
-        auto low_dim_hidden_ids_1_mod = to_multi_index(low_dim_hidden_ids_1);
+        auto low_dim_hidden_ids_1_mod_ = to_multi_index(low_dim_hidden_ids_1);

        // shift hidden id so every dim id is unique
        static_for<0, ndim_low_1, 1>{}([&](auto idim_low_1) {
-            low_dim_hidden_ids_1_mod(idim_low_1) += adaptor1_hidden_id_shift;
+            low_dim_hidden_ids_1_mod_(idim_low_1) += adaptor1_hidden_id_shift;
        });

        // match hidden id

@@ -335,13 +335,13 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
            if constexpr(low_dim_hidden_ids_1[idim_low_1] ==
                         TensorAdaptor1::GetBottomDimensionHiddenIds()[idim_bottom_1])
            {
-                low_dim_hidden_ids_1_mod(idim_low_1) =
+                low_dim_hidden_ids_1_mod_(idim_low_1) =
                    TensorAdaptor0::GetTopDimensionHiddenIds()[idim_bottom_1];
            }
        });
    });

-        return low_dim_hidden_ids_1_mod;
+        return low_dim_hidden_ids_1_mod_;
    }
    ();

@@ -367,14 +367,14 @@ __host__ __device__ constexpr auto chain_tensor_adaptors(const TensorAdaptor0& a
    // sequence in, constexpr tuple out
    constexpr auto up_dim_hidden_ids_1_mod = [&]() constexpr
    {
-        auto up_dim_hidden_ids_1_mod = to_multi_index(up_dim_hidden_ids_1);
+        auto up_dim_hidden_ids_1_mod_ = to_multi_index(up_dim_hidden_ids_1);

        // shift hidden id
        static_for<0, ndim_up_1, 1>{}([&](auto idim_up_1) {
-            up_dim_hidden_ids_1_mod(idim_up_1) += adaptor1_hidden_id_shift;
+            up_dim_hidden_ids_1_mod_(idim_up_1) += adaptor1_hidden_id_shift;
        });

-        return up_dim_hidden_ids_1_mod;
+        return up_dim_hidden_ids_1_mod_;
    }
    ();

@@ -14,7 +14,7 @@ namespace ck {
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
-          InMemoryDataOperation DstInMemOp,
+          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadSliceLengths,
          typename ThreadClusterLengths,

@@ -14,7 +14,7 @@ namespace ck {
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep reference to tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct new tensor coordinate
template <index_t BlockSize,
-          InMemoryDataOperation DstInMemOp,
+          InMemoryDataOperationEnum_t DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadSliceLengths,
          typename ThreadClusterLengths,
@@ -1,10 +1,10 @@
-#ifndef CK_BLOCKWISE_GEMM_V2R2_HPP
-#define CK_BLOCKWISE_GEMM_V2R2_HPP
+#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
+#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
-#include "threadwise_contraction.hpp"
+#include "threadwise_contraction_dlops.hpp"

namespace ck {

@@ -40,7 +40,7 @@ template <index_t BlockSize,
          typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
                                      BKNBlockDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
-struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
+struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;

@@ -140,7 +140,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
    static constexpr auto b_k_n0_n1_block_desc_ = MakeBKN0N1BlockDescriptor(BKNBlockDesc{});

    public:
-    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
+    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
              get_thread_local_1d_id())},
          a_thread_copy_{

@@ -183,7 +183,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

-        return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
+        return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
    }

    __host__ __device__ static constexpr index_t GetABlockAlignment() { return M1PerThreadM11; }

@@ -207,21 +207,21 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2
                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
                      "wrong");

-        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
            a_k_m0_m1_thread_desc_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
            b_k_n0_n1_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
-            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
-                                                FloatB,
-                                                FloatC,
-                                                decltype(a_k_m0_m1_thread_desc_),
-                                                decltype(b_k_n0_n1_thread_desc_),
-                                                CM0M1N0N1ThreadDesc,
-                                                Sequence<KPerThread>,
-                                                Sequence<1, M1PerThreadM11>,
-                                                Sequence<1, N1PerThreadN11>>{};
+            ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1<FloatA,
+                                                     FloatB,
+                                                     FloatC,
+                                                     decltype(a_k_m0_m1_thread_desc_),
+                                                     decltype(b_k_n0_n1_thread_desc_),
+                                                     CM0M1N0N1ThreadDesc,
+                                                     Sequence<KPerThread>,
+                                                     Sequence<1, M1PerThreadM11>,
+                                                     Sequence<1, N1PerThreadN11>>{};

        // read A_sub_0
        a_thread_copy_.Run(a_k_m0_m1_block_desc_,

@@ -1,10 +1,10 @@
-#ifndef CK_BLOCKWISE_GEMM_V2R3_HPP
-#define CK_BLOCKWISE_GEMM_V2R3_HPP
+#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP
+#define CK_BLOCKWISE_GEMM_DLOPS_V2R3_HPP

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
-#include "threadwise_contraction.hpp"
+#include "threadwise_contraction_dlops.hpp"

namespace ck {

@@ -21,6 +21,7 @@ namespace ck {
// 1. CThreadDesc_BM0_BM11_BN0_BN11 is known at compile-time
// 2. CThreadBuffer is StaticBuffer
// Also assume:
+// BM10BN10ThreadClusterBM10Xs::Size() = BM10BN10ThreadClusterBN10Xs::Size() == 2
// BM0 = BN0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
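//
// Sketch of the assumed two-level decomposition: with
// BM10BN10ThreadClusterBM10Xs = Sequence<BM100, BM101>, the per-block BM1 extent
// factors as BM1 = BM100 * BM101 * BM1PerThreadBM11, i.e. two nested thread
// clusters around a per-thread vector; the Size() == 2 restriction is asserted
// in the constructor below.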
template <index_t BlockSize,
          typename FloatA,

@@ -31,16 +32,16 @@ template <index_t BlockSize,
          index_t BM1PerThreadBM11,
          index_t BN1PerThreadBN11,
          index_t BK0PerThread,
-          index_t BM10BN10ThreadClusterBM100,
-          index_t BM10BN10ThreadClusterBN100,
-          index_t BM10BN10ThreadClusterBM101,
-          index_t BM10BN10ThreadClusterBN101,
+          typename BM10BN10ThreadClusterBM10Xs, // Sequence<BM10BN10ThreadClusterBM100,
+                                                //          BM10BN10ThreadClusterBM101, ...>
+          typename BM10BN10ThreadClusterBN10Xs, // Sequence<BM10BN10ThreadClusterBN100,
+                                                //          BM10BN10ThreadClusterBN101, ...>
          index_t AThreadCopyScalarPerVector_BM11,
          index_t BThreadCopyScalarPerVector_BN11,
          typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
                                      BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
                                  bool>::type = false>
-struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
+struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;

@@ -56,19 +57,17 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
    static constexpr index_t BM = ABlockDesc_BK0_BM_BK1{}.GetLength(I1);
    static constexpr index_t BN = BBlockDesc_BK0_BN_BK1{}.GetLength(I1);

-    static constexpr index_t BM100 = BM10BN10ThreadClusterBM100;
-    static constexpr index_t BN100 = BM10BN10ThreadClusterBN100;
+    static constexpr index_t BM100 = BM10BN10ThreadClusterBM10Xs{}[I0];
+    static constexpr index_t BN100 = BM10BN10ThreadClusterBN10Xs{}[I0];

-    static constexpr index_t BM101 = BM10BN10ThreadClusterBM101;
-    static constexpr index_t BN101 = BM10BN10ThreadClusterBN101;
+    static constexpr index_t BM101 = BM10BN10ThreadClusterBM10Xs{}[I1];
+    static constexpr index_t BN101 = BM10BN10ThreadClusterBN10Xs{}[I1];

    static constexpr index_t BM11 = BM1PerThreadBM11;
    static constexpr index_t BN11 = BN1PerThreadBN11;

-    static constexpr index_t BM1 =
-        BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11;
-    static constexpr index_t BN1 =
-        BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11;
+    static constexpr index_t BM1 = BM100 * BM101 * BM11;
+    static constexpr index_t BN1 = BN100 * BN101 * BN11;

    static constexpr index_t BM0 = BM / BM1;
    static constexpr index_t BN0 = BN / BN1;
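
    // Example with hypothetical numbers: BM10BN10ThreadClusterBM10Xs = Sequence<2, 8>
    // and BM1PerThreadBM11 = 4 give BM100 = 2, BM101 = 8, BM11 = 4, hence
    // BM1 = 2 * 8 * 4 = 64; a block tile of BM = 128 then yields BM0 = 2, matching
    // the BM0 == 2 restriction asserted in the constructor.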

@@ -149,7 +148,7 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
        MakeBBlockDescriptor_BK0_BN0_BN1_BK1(BBlockDesc_BK0_BN_BK1{});

    public:
-    __device__ BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
+    __device__ BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2()
        : c_thread_origin_data_idx_{CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
              get_thread_local_1d_id())},
          a_thread_copy_{

@@ -170,6 +169,11 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
                          BBlockDesc_BK0_BN_BK1{}.GetLength(I0),
                      "wrong! K dimension not consistent");

+        // TODO remove this restriction
+        static_assert(BM10BN10ThreadClusterBM10Xs::Size() == 2 &&
+                          BM10BN10ThreadClusterBN10Xs::Size() == 2,
+                      "wrong!");
+
        // TODO: remove this restriction
        static_assert(BM0 == 2 && BN0 == 2, "wrong");
    }

@@ -195,14 +199,14 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

-        return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
+        return adaptor.CalculateBottomIndex(make_multi_index(thread_id, 0, 0, 0, 0));
    }

    template <typename CThreadDesc_BM0_BM11_BN0_BN11,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
-    __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11& c_m0_m1_n0_n1_thread_desc,
+    __device__ void Run(const CThreadDesc_BM0_BM11_BN0_BN11&,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const

@@ -216,13 +220,13 @@ struct BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_
                          CThreadDesc_BM0_BM11_BN0_BN11{}.GetLength(I2) == BN0,
                      "wrong");

-        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
+        auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatA>(
            a_thread_desc_bk0_bm0_bm1_bk1_.GetElementSpaceSize());
-        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
+        auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatB>(
            b_thread_desc_bk0_bn0_bn1_bk1_.GetElementSpaceSize());

        constexpr auto threadwise_contraction =
-            ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
+            ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1<
                FloatA,
                FloatB,
                FloatC,

@@ -1,8 +1,8 @@
-#ifndef CK_BLOCKWISE_GEMM_V3_HPP
-#define CK_BLOCKWISE_GEMM_V3_HPP
+#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
+#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP

#include "common_header.hpp"
-#include "threadwise_gemm_v3.hpp"
+#include "threadwise_gemm_dlops_v3.hpp"

namespace ck {

@@ -19,7 +19,7 @@ template <index_t BlockSize,
          index_t EPerThreadLoop,
          index_t ThreadGemmADataPerRead_K,
          index_t ThreadGemmBDataPerRead_W>
-struct BlockwiseGemm_km_kn_m0m1n0n1_v3
+struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
    struct MatrixIndex
    {

@@ -51,7 +51,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
                                          ThreadGemmADataPerRead_K,
                                          1>;

-    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
+    __device__ BlockwiseGemmDlops_km_kn_m0m1n0n1_v3()
        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
          a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
    {

@@ -138,16 +138,17 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
        static_assert(WPerThread % WoPerThreadSubC == 0, "");

        // thread A buffer for GEMM
-        StaticBuffer<AddressSpace::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+        StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatA, a_thread_mtx_.GetElementSpaceSize()>
+            a_thread_buf;

-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
-                                                                    FloatB,
-                                                                    FloatC,
-                                                                    decltype(a_thread_mtx_),
-                                                                    decltype(b_thread_mtx_),
-                                                                    decltype(c_thread_mtx_),
-                                                                    HoPerThreadSubC,
-                                                                    WoPerThreadSubC>{};
+        constexpr auto threadwise_gemm = ThreadwiseGemmDlops_km_kn_mn_v3<FloatA,
+                                                                         FloatB,
+                                                                         FloatC,
+                                                                         decltype(a_thread_mtx_),
+                                                                         decltype(b_thread_mtx_),
+                                                                         decltype(c_thread_mtx_),
+                                                                         HoPerThreadSubC,
+                                                                         WoPerThreadSubC>{};

        static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
            static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {

@@ -1,514 +0,0 @@
#ifndef CK_BLOCKWISE_GEMM_V2_HPP
#define CK_BLOCKWISE_GEMM_V2_HPP

#include "common_header.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_gemm_v2.hpp"

namespace ck {

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among each thread
// Assume:
// 1. A:
//   1. ABlockDesc is known at compile-time
//   2. ABlockBuffer is DynamicBuffer
// 2. B:
//   1. BBlockDesc is known at compile-time
//   2. BBlockBuffer is DynamicBuffer
// 3. C:
//   1. CThreadDesc is known at compile-time
//   2. CThreadBuffer is StaticBuffer
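//
// Minimal usage sketch (buffer and descriptor types come from the surrounding
// gridwise GEMM; template arguments are whatever satisfies the static_asserts
// in the constructor below):
//
//   auto blockwise_gemm = BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1<...>{};
//   blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf);
//
// One Run() call sweeps the whole K extent of the block tile, accumulating into
// each thread's private C buffer.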
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ABlockDesc,
          typename BBlockDesc,
          typename CThreadDesc,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          index_t AThreadCopyScalarPerVector_M1,
          index_t BThreadCopyScalarPerVector_N1,
          typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
                                      BBlockDesc::IsKnownAtCompileTime() &&
                                      CThreadDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    public:
    __device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1()
        : c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
    {
        static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
                          CThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
                                       M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");
    }

    __device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
    {
        constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
        constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
        constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
        constexpr index_t N1 = BBlockDesc{}.GetLength(I2);

        // 4-d data space into 4-d thread space
        // upper: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1, M1N1ThreadClusterN10 *
        // M1N1ThreadClusterN11} lower: {M0, M1, N0, N1}
        constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
            make_tuple(make_vectorize_transform(M0, 1),
                       make_vectorize_transform(M1PerThread, M1 / M1PerThread),
                       make_vectorize_transform(N0, 1),
                       make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        // thread position 4-d thread space
        // upper: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
        // M1N1ThreadClusterN11} lower: {1, M1N1ThreadClusterM10 * M1N1ThreadClusterM11, 1,
        // M1N1ThreadClusterN10 * M1N1ThreadClusterN11}
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
                make_freeze_transform(make_multi_index(0)),
                make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));

        // 4-d thread space to 1-d thread space
        // upper: {BlockSize}
        // lower: {M1N1ThreadClusterM10, M1N1ThreadClusterM11, M1N1ThreadClusterN10,
        // M1N1ThreadClusterN11}
        constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
                                                       M1N1ThreadClusterN10,
                                                       M1N1ThreadClusterM11,
                                                       M1N1ThreadClusterN11))),
            make_tuple(Sequence<0, 2, 1, 3>{}),
            make_tuple(Sequence<0>{}));

        constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);

        return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
    }

    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf =
            make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
                                                FloatB,
                                                FloatC,
                                                decltype(a_thread_desc_),
                                                decltype(b_thread_desc_),
                                                CThreadDesc,
                                                Sequence<KPerThread>,
                                                Sequence<M0_, M1PerThread>,
                                                Sequence<N0_, N1PerThread>>{};

        constexpr index_t K = ABlockDesc{}.GetLength(I0);

        static_for<0, K, KPerThread>{}([&](auto k) {
            a_thread_copy_.Run(ABlockDesc{},
                               make_tuple(k, I0, I0),
                               a_block_buf,
                               a_thread_desc_,
                               make_tuple(I0, I0, I0),
                               a_thread_buf);

            b_thread_copy_.Run(BBlockDesc{},
                               make_tuple(k, I0, I0),
                               b_block_buf,
                               b_thread_desc_,
                               make_tuple(I0, I0, I0),
                               b_thread_buf);

            threadwise_gemm.Run(a_thread_buf,
                                make_tuple(I0, I0, I0),
                                b_thread_buf,
                                make_tuple(I0, I0, I0),
                                c_thread_buf,
                                make_tuple(I0, I0, I0, I0));
        });
    }

    private:
    static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
    static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);

    // A[K, M0, M1]
    static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));

    // B[K, N0, N1]
    static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
        make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));

    using AThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
                                                FloatA,
                                                ABlockDesc,
                                                decltype(a_thread_desc_),
                                                Sequence<KPerThread, M0_, M1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                AThreadCopyScalarPerVector_M1,
                                                1>;

    using BThreadCopy =
        ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
                                                FloatB,
                                                BBlockDesc,
                                                decltype(b_thread_desc_),
                                                Sequence<KPerThread, N0_, N1PerThread>,
                                                Sequence<0, 1, 2>,
                                                2,
                                                BThreadCopyScalarPerVector_N1,
                                                1>;

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block, C is distributed among each thread
// Assume:
// 1. A:
//   1. ABlockDesc is known at compile-time
//   2. ABlockBuffer is DynamicBuffer
// 2. B:
//   1. BBlockDesc is known at compile-time
//   2. BBlockBuffer is DynamicBuffer
// 3. C:
//   1. CThreadDesc is known at compile-time
//   2. CThreadBuffer is StaticBuffer
// Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
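//
// The ABBA schedule spelled out (this matches the reads and fmas in Run below):
// the prologue reads A_sub_0, B_sub_0, B_sub_1, A_sub_1 and computes C_sub_00,
// C_sub_01; the steady-state k-loop then interleaves
//   read A_sub_0(k); C_sub_10 += A1*B0; read B_sub_0(k); C_sub_11 += A1*B1;
//   read B_sub_1(k); read A_sub_1(k);   C_sub_00 += A0*B0; C_sub_01 += A0*B1;
// so each read targets a register sub-buffer not used by the neighboring fma,
// letting reads and fmas overlap.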
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename ABlockDesc,
|
||||
typename BBlockDesc,
|
||||
typename CThreadDesc,
|
||||
index_t M1PerThread,
|
||||
index_t N1PerThread,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM10,
|
||||
index_t M1N1ThreadClusterN10,
|
||||
index_t M1N1ThreadClusterM11,
|
||||
index_t M1N1ThreadClusterN11,
|
||||
index_t AThreadCopyScalarPerVector_M1,
|
||||
index_t BThreadCopyScalarPerVector_N1,
|
||||
typename std::enable_if<ABlockDesc::IsKnownAtCompileTime() &&
|
||||
BBlockDesc::IsKnownAtCompileTime() &&
|
||||
CThreadDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
using BIndex = MultiIndex<3>;
|
||||
using CIndex = MultiIndex<4>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
public:
|
||||
__device__ BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v2_pipeline_2x2()
|
||||
: c_thread_origin_data_idx_{CalculateCThreadOriginDataIndex(get_thread_local_1d_id())},
|
||||
a_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1])},
|
||||
b_thread_copy_{
|
||||
make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3])}
|
||||
{
|
||||
static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime() &&
|
||||
CThreadDesc::IsKnownAtCompileTime(),
|
||||
"wrong! Desc should be known at compile-time");
|
||||
|
||||
static_assert(BlockSize == M1N1ThreadClusterM11 * M1N1ThreadClusterM10 *
|
||||
M1N1ThreadClusterN11 * M1N1ThreadClusterN10,
|
||||
"wrong! blocksize and cluster size not consistent");
|
||||
|
||||
static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0),
|
||||
"wrong! K dimension not consistent");
|
||||
|
||||
// TODO: remove this restriction
|
||||
static_assert(ABlockDesc{}.GetLength(I1) == 2 && BBlockDesc{}.GetLength(I1) == 2 &&
|
||||
CThreadDesc{}.GetLength(I0) == 2 && CThreadDesc{}.GetLength(I2) == 2,
|
||||
"wrong");
|
||||
}
|
||||
|
||||
__device__ static CIndex CalculateCThreadOriginDataIndex(index_t thread_id)
|
||||
{
|
||||
constexpr index_t M0 = ABlockDesc{}.GetLength(I1);
|
||||
constexpr index_t N0 = BBlockDesc{}.GetLength(I1);
|
||||
constexpr index_t M1 = ABlockDesc{}.GetLength(I2);
|
||||
constexpr index_t N1 = BBlockDesc{}.GetLength(I2);
|
||||
|
||||
// 4-d data space into 4-d thread space
|
||||
constexpr auto adaptor0 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_vectorize_transform(M0, 1),
|
||||
make_vectorize_transform(M1PerThread, M1 / M1PerThread),
|
||||
make_vectorize_transform(N0, 1),
|
||||
make_vectorize_transform(N1PerThread, N1 / N1PerThread)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
// thread position 4-d thread space
|
||||
constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(
|
||||
make_freeze_transform(make_multi_index(0)),
|
||||
make_unmerge_transform(make_tuple(M1N1ThreadClusterM10, M1N1ThreadClusterM11)),
|
||||
make_freeze_transform(make_multi_index(0)),
|
||||
make_unmerge_transform(make_tuple(M1N1ThreadClusterN10, M1N1ThreadClusterN11))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<>{}, Sequence<0, 1>{}, Sequence<>{}, Sequence<2, 3>{}));
|
||||
|
||||
// 4-d thread space to 1-d thread space
|
||||
constexpr auto adaptor2 = make_single_stage_tensor_adaptor(
|
||||
make_tuple(make_merge_transform(make_tuple(M1N1ThreadClusterM10,
|
||||
M1N1ThreadClusterN10,
|
||||
M1N1ThreadClusterM11,
|
||||
M1N1ThreadClusterN11))),
|
||||
make_tuple(Sequence<0, 2, 1, 3>{}),
|
||||
make_tuple(Sequence<0>{}));
|
||||
|
||||
constexpr auto cluster_desc = chain_tensor_adaptors(adaptor0, adaptor1, adaptor2);
|
||||
|
||||
return cluster_desc.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id()));
|
||||
}
|
||||
|
||||
template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
__device__ void Run(const ABlockBuffer& a_block_buf,
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatA>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatB>(b_thread_desc_.GetElementSpaceSize());

constexpr auto threadwise_gemm =
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1<FloatA,
FloatB,
FloatC,
decltype(a_thread_desc_),
decltype(b_thread_desc_),
CThreadDesc,
Sequence<KPerThread>,
Sequence<1, M1PerThread>,
Sequence<1, N1PerThread>>{};

constexpr index_t K = ABlockDesc{}.GetLength(I0);

// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);

// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);

// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(I0, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);

// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(I0, I1, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);

// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));

// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));

// loop over rest of k
static_for<KPerThread, K, KPerThread>{}([&](auto k) {
// read A_sub_0
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I0, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I0, I0),
a_thread_buf);

// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));

// read B_sub_0
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I0, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I0, I0),
b_thread_buf);

// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));

// read B_sub_1
b_thread_copy_.Run(BBlockDesc{},
make_tuple(k, I1, I0),
b_block_buf,
b_thread_desc_,
make_tuple(I0, I1, I0),
b_thread_buf);

// read A_sub_1
a_thread_copy_.Run(ABlockDesc{},
make_tuple(k, I1, I0),
a_block_buf,
a_thread_desc_,
make_tuple(I0, I1, I0),
a_thread_buf);

// C_sub_00 += transpose(A_sub_0) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I0, I0, I0, I0));

// C_sub_01 += transpose(A_sub_0) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I0, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I0, I0, I1, I0));
});

// C_sub_10 += transpose(A_sub_1) * B_sub_0
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I0, I0),
c_thread_buf,
make_tuple(I1, I0, I0, I0));

// C_sub_11 += transpose(A_sub_1) * B_sub_1
threadwise_gemm.Run(a_thread_buf,
make_tuple(I0, I1, I0),
b_thread_buf,
make_tuple(I0, I1, I0),
c_thread_buf,
make_tuple(I1, I0, I1, I0));
}
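// [editor's note] Schedule of the 2x2 pipelined Run() above, summarized as an
// illustration (not CK code): with two sub-tiles per side, the LDS reads for
// the next sub-tiles are interleaved with the four C sub-tile GEMMs so the FMA
// units are never starved while waiting on loads.
//
//   prologue : load A0 B0 B1 A1;  C00 += A0*B0;  C01 += A0*B1
//   each k   : load A0';          C10 += A1*B0
//              load B0';          C11 += A1*B1
//              load B1' A1';      C00 += A0'*B0'; C01 += A0'*B1'
//   epilogue : C10 += A1*B0;      C11 += A1*B1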
private:
static constexpr index_t M0_ = ABlockDesc{}.GetLength(I1);
static constexpr index_t N0_ = BBlockDesc{}.GetLength(I1);

// A[K, M0, M1]
static constexpr auto a_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<M0_>{}, Number<M1PerThread>{}));

// B[K, N0, N1]
static constexpr auto b_thread_desc_ = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<KPerThread>{}, Number<N0_>{}, Number<N1PerThread>{}));

using AThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
FloatA,
ABlockDesc,
decltype(a_thread_desc_),
Sequence<KPerThread, 1, M1PerThread>,
Sequence<0, 1, 2>,
2,
AThreadCopyScalarPerVector_M1,
1>;

using BThreadCopy =
ThreadwiseDynamicTensorSliceTransfer_v4<FloatB,
FloatB,
BBlockDesc,
decltype(b_thread_desc_),
Sequence<KPerThread, 1, N1PerThread>,
Sequence<0, 1, 2>,
2,
BThreadCopyScalarPerVector_N1,
1>;

CIndex c_thread_origin_data_idx_;

AThreadCopy a_thread_copy_;
BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif
@@ -138,10 +138,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize());
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());

constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);

@@ -358,10 +358,10 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline
const BBlockBuffer& b_block_buf,
CThreadBuffer& c_thread_buf) const
{
auto a_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatAB>(a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf =
make_static_buffer<AddressSpace::Vgpr, FloatAB>(b_thread_desc_.GetElementSpaceSize());
auto a_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
a_thread_desc_.GetElementSpaceSize());
auto b_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAB>(
b_thread_desc_.GetElementSpaceSize());

constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0);

@@ -1,11 +1,11 @@
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_DLOPS_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r3.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
@@ -25,7 +25,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_contraction_v1r2(
kernel_dynamic_contraction_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
@@ -55,7 +55,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGridDesc_GK0_GM0_GM1_GK1,
typename BGridDesc_GK0_GN0_GN1_GK1,
typename CGridDesc_GM0_GM1_GN0_GN1,
@@ -65,10 +65,8 @@ template <index_t BlockSize,
index_t BM1PerThreadBM11,
index_t BN1PerThreadBN11,
index_t BK0PerThread,
index_t BM10BN10ThreadClusterBM100,
index_t BM10BN10ThreadClusterBN100,
index_t BM10BN10ThreadClusterBM101,
index_t BM10BN10ThreadClusterBN101,
typename BM10BN10ThreadClusterBM10Xs,
typename BM10BN10ThreadClusterBN10Xs,
typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
typename ABlockTransferThreadClusterArrangeOrder,
@@ -91,7 +89,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
struct GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -252,9 +250,11 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
constexpr auto BN = GN0 * GN11;

constexpr auto BM1 =
Number<BM10BN10ThreadClusterBM100 * BM10BN10ThreadClusterBM101 * BM1PerThreadBM11>{};
Number<container_reduce(BM10BN10ThreadClusterBM10Xs{}, math::multiplies_v2{}, I1) *
BM1PerThreadBM11>{};
constexpr auto BN1 =
Number<BM10BN10ThreadClusterBN100 * BM10BN10ThreadClusterBN101 * BN1PerThreadBN11>{};
Number<container_reduce(BM10BN10ThreadClusterBN10Xs{}, math::multiplies_v2{}, I1) *
BN1PerThreadBN11>{};

constexpr auto BM0 = BM / BM1;
constexpr auto BN0 = BN / BN1;
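// [editor's note] container_reduce above folds the whole Sequence of cluster
// lengths into one product, so the thread cluster is no longer fixed to exactly
// two levels (BM100/BM101). A minimal constexpr sketch of the same fold, with
// an illustrative helper name that is not part of CK:
//
// template <index_t... Ls>
// __host__ __device__ constexpr index_t product_of(Sequence<Ls...>)
// {
//     return (Ls * ... * 1); // e.g. product_of(Sequence<8, 2>{}) == 8 * 2
// }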
@@ -331,11 +331,11 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_grid_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1.GetElementSpaceSize());

const auto GK0 = a_grid_desc_gk0_gm0_gm10_gm11_gk1.GetLength(I0);
@@ -387,7 +387,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
@@ -411,7 +411,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<GK0PerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
@@ -439,7 +439,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
// c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
@@ -449,10 +449,8 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
BM1PerThreadBM11,
BN1PerThreadBN11,
BK0PerThread,
BM10BN10ThreadClusterBM100,
BM10BN10ThreadClusterBN100,
BM10BN10ThreadClusterBM101,
BM10BN10ThreadClusterBN101,
BM10BN10ThreadClusterBM10Xs,
BM10BN10ThreadClusterBN10Xs,
BM1PerThreadBM11,
BN1PerThreadBN11>{};

@@ -474,7 +472,7 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_thread_desc_bm0_bm1_bn0_bn1.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
@@ -488,15 +486,15 @@ struct GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_
constexpr auto a_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(GK0PerBlock, 0, 0, 0, 0);

auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());

auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_block_desc_gk0_gm0_gm10_gm11_gk1.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_block_desc_gk0_gn0_gn10_gn11_gk1.GetElementSpaceSize());

@@ -1,11 +1,11 @@
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R2_HPP
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_DLOPS_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r2.hpp"
#include "blockwise_gemm_dlops_v2r2.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r2(
kernel_dynamic_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
@@ -52,8 +52,8 @@ __global__ void
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
typename FloatAB,
@@ -68,16 +68,16 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r2(
kernel_dynamic_gemm_dlops_v1r2(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k_m0_m1_grid_desc,
const void __CONSTANT__* p_b_k_n0_n1_grid_desc,
const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
const void CONSTANT* p_a_k_m0_m1_grid_desc,
const void CONSTANT* p_b_k_n0_n1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast void __CONSTANT__ void* to void*
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k_m0_m1_grid_desc =
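// [editor's note] The double cast described above, written out as a sketch
// (the descriptor type name is illustrative; CONSTANT expands to the
// address_space(4) qualifier):
//
// const void* p_generic = (const void*)p_a_k_m0_m1_grid_desc; // drop address_space(4)
// const auto desc       = *(const AKM0M1GridDesc*)p_generic;  // reinterpret as descriptor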
@@ -113,7 +113,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AKMGridDesc,
typename BKNGridDesc,
typename CMNGridDesc,
@@ -151,7 +151,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1r2
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r2
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -326,11 +326,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k_m0_m1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k_n0_n1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

const auto K = a_k_m0_m1_grid_desc.GetLength(I0);
@@ -373,7 +373,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1>,
ABlockTransferThreadSliceLengths_K_M0_M1,
ABlockTransferThreadClusterLengths_K_M0_M1,
@@ -399,7 +399,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1>,
BBlockTransferThreadSliceLengths_K_N0_N1,
BBlockTransferThreadClusterLengths_K_N0_N1,
@@ -429,21 +429,21 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_k_m_block_desc),
decltype(b_k_n_block_desc),
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M1PerThreadM111,
N1PerThreadN111>{};
constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

@@ -462,7 +462,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
@@ -487,17 +487,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r2
constexpr auto b_k_n0_n1_global_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};

auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_k_n0_n1_block_desc.GetElementSpaceSize());

auto a_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k_m0_m1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k_n0_n1_block_desc.GetElementSpaceSize());

// LDS double buffer: preload data into LDS
{
@@ -5,7 +5,7 @@
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r3.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"
@@ -26,7 +26,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r3(
kernel_dynamic_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
@@ -52,8 +52,8 @@ __global__ void
integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptor by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform compiler void pointers in the kernel signature are pointing to
// pass tensor descriptor by CONSTANT void pointer
// CONSTANT is needed to inform compiler void pointers in the kernel signature are pointing to
// non-modifiable parameter address space, so compiler can enable corresponding optimization
template <typename GridwiseGemm,
typename FloatAB,
@@ -68,16 +68,16 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
kernel_dynamic_gemm_v1r3(
kernel_dynamic_gemm_dlops_v1r3(
const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m0_m1_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n0_n1_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
const void CONSTANT* p_a_k0_m0_m1_k1_grid_desc,
const void CONSTANT* p_b_k0_n0_n1_k1_grid_desc,
const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
// first cast void __CONSTANT__ void* to void*
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
const auto a_k0_m0_m1_k1_grid_desc =
@@ -113,7 +113,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
@@ -123,10 +123,8 @@ template <index_t BlockSize,
index_t M1PerThreadM111,
index_t N1PerThreadN111,
index_t KPerThread,
index_t M11N11ThreadClusterM1100,
index_t M11N11ThreadClusterN1100,
index_t M11N11ThreadClusterM1101,
index_t M11N11ThreadClusterN1101,
typename M11N11ThreadClusterM110Xs,
typename M11N11ThreadClusterN110Xs,
typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
typename ABlockTransferThreadClusterArrangeOrder,
@@ -149,7 +147,7 @@ template <index_t BlockSize,
typename CGridIteratorHacks,
typename AGridMoveSliceWindowIteratorHacks,
typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1r3
struct GridwiseDynamicGemmDlops_km_kn_mn_v1r3
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
@@ -277,9 +275,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
const auto N0 = N / N1;

constexpr auto M11 =
Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
N1PerThreadN111>{};

constexpr auto M10 = M1 / M11;
constexpr auto N10 = N1 / N11;
@@ -333,11 +333,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
integral_constant<bool, HasMainKBlockLoop>,
integral_constant<bool, HasDoubleTailKBlockLoop>)
{
const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

// divide block work by [M, N]
@@ -383,7 +383,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// A matrix blockwise copy
auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
@@ -407,7 +407,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// B matrix blockwise copy
auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
@@ -435,7 +435,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
// c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
// register
const auto blockwise_gemm =
BlockwiseGemm_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2<
BlockSize,
FloatAB,
FloatAB,
@@ -445,15 +445,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
M1PerThreadM111,
N1PerThreadN111,
KPerThread,
M11N11ThreadClusterM1100,
M11N11ThreadClusterN1100,
M11N11ThreadClusterM1101,
M11N11ThreadClusterN1101,
M11N11ThreadClusterM110Xs,
M11N11ThreadClusterN110Xs,
M1PerThreadM111,
N1PerThreadN111>{};

constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();
decltype(blockwise_gemm)::GetCThreadTensorLengths_BM0_BM1_BN0_BN1();

constexpr auto c_m10_m11_n10_n11_thread_desc =
make_dynamic_naive_tensor_descriptor_packed_v2(
@@ -470,7 +468,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

// register allocation for output
auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
auto c_thread_buf = make_static_buffer<AddressSpaceEnum_t::Vgpr, FloatAcc>(
c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
@@ -484,17 +482,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);

auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto a_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto b_block_even_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());

auto a_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf =
make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());
auto a_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block_double + a_block_aligned_space_size,
a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
auto b_block_odd_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block_double + b_block_aligned_space_size,
b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());

// LDS double buffer: preload data into LDS
{
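// [editor's note] The even/odd LDS buffers above form a ping-pong pair. A rough
// schematic of the main K loop (illustration only; helper names hypothetical):
//
// write_lds(even);                                  // preload first tile
// for(; k < K; k += 2 * KPerBlock) {
//     read_global(next); sync(); gemm(even); write_lds(odd);
//     read_global(next); sync(); gemm(odd);  write_lds(even);
// }
// // HasMainKBlockLoop / HasDoubleTailKBlockLoop select how the last one or
// // two tiles are drained in the epilogue.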
@@ -610,10 +608,12 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3

// output: register to global memory
{
constexpr index_t M11 =
M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
constexpr index_t N11 =
N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;
constexpr auto M11 =
Number<container_reduce(M11N11ThreadClusterM110Xs{}, math::multiplies_v2{}, I1) *
M1PerThreadM111>{};
constexpr auto N11 =
Number<container_reduce(M11N11ThreadClusterN110Xs{}, math::multiplies_v2{}, I1) *
N1PerThreadN111>{};

constexpr index_t M10 = MPerBlockM1 / M11;
constexpr index_t N10 = NPerBlockN1 / N11;
@@ -631,7 +631,8 @@ struct GridwiseDynamicGemm_km_kn_mn_v1r3
Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());
blockwise_gemm.CalculateCThreadOriginOnBlock_BM0_BM1_BN0_BN1(
get_thread_local_1d_id());

ThreadwiseDynamicTensorSliceTransfer_v1r3<
FloatAcc,
@@ -7,7 +7,7 @@
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
#include "blockwise_gemm_v3.hpp"
#include "blockwise_gemm_dlops_v3.hpp"

namespace ck {

@@ -15,7 +15,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AGlobalDesc,
typename BGlobalDesc,
typename CGlobalDesc,
@@ -47,7 +47,7 @@ template <index_t BlockSize,
typename CGlobalIteratorHacks,
typename AGlobalMoveSliceWindowIteratorHacks,
typename BGlobalMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v3
struct GridwiseDynamicGemmDlops_km_kn_mn_v3
{
__host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
{
@@ -84,11 +84,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_global, a_e_k_global_desc.GetElementSpaceSize());
const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_global, b_e_n_ho_wo_global_desc.GetElementSpaceSize());
auto c_global_buf = make_dynamic_buffer<AddressSpace::Global>(
auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_global, c_k_n_ho_wo_global_desc.GetElementSpaceSize());

constexpr auto E = EPerBlock * 3 * 3;
@@ -100,7 +100,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
const auto Ho = b_e_n_ho_wo_global_desc.GetLength(I2);
const auto Wo = b_e_n_ho_wo_global_desc.GetLength(I3);

// divide block work by [M, N]
// divide block work by [M, N]
#if 0
const auto k_block_work_num = K / Number<KPerBlock>{};
const auto ho_block_work_num = Ho / Number<HoPerBlock>{};
@@ -152,19 +152,20 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

auto blockwise_gemm = BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};
auto blockwise_gemm =
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3<BlockSize,
FloatAB,
FloatAB,
FloatAcc,
decltype(a_e_k_block_desc),
decltype(b_e_n_ho_wo_block_desc),
decltype(c_k_n_ho_wo_thread_desc),
KPerThread,
HoPerThread,
WoPerThread,
EPerThread,
ABlockTransferSrcScalarPerVector,
ABlockTransferDstScalarPerVector_K>{};

auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

@@ -184,7 +185,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<E, KPerBlock>,
ABlockTransferThreadSliceLengths_E_K,
ABlockTransferThreadClusterLengths_E_K,
@@ -225,11 +226,13 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
true>(b_e_n_ho_wo_global_desc,
make_multi_index(0, 0, ho_thread_data_on_global, wo_thread_data_on_global));

auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(p_shared_block,
a_e_k_desc.GetElementSpaceSize());
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_shared_block, a_e_k_desc.GetElementSpaceSize());

// register allocation for output
StaticBuffer<AddressSpace::Vgpr, FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAcc,
c_k_n_ho_wo_thread_desc.GetElementSpaceSize()>
c_thread_buf;

// initialize output thread tensor
@@ -252,7 +255,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
BGlobalMoveSliceWindowIteratorHacks{};

// double register buffer for b
StaticBuffer<AddressSpace::Vgpr, FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr,
FloatAB,
b_e_n_ho_wo_thread_desc.GetElementSpaceSize()>
b_thread_even_buf, b_thread_odd_buf;

// LDS double buffer: preload data
@@ -61,10 +61,10 @@ __global__ void
kernel_dynamic_gemm_xdlops_v2r3(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
const void __CONSTANT__* p_c_block_cluster_adaptor)
const void CONSTANT* p_a_k0_m_k1_grid_desc,
const void CONSTANT* p_b_k0_n_k1_grid_desc,
const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
const void CONSTANT* p_c_block_cluster_adaptor)
{
constexpr index_t shared_block_size =
GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
@@ -95,7 +95,7 @@ template <index_t BlockSize,
typename FloatAB,
typename FloatAcc,
typename FloatC,
InMemoryDataOperation CGlobalMemoryDataOperation,
InMemoryDataOperationEnum_t CGlobalMemoryDataOperation,
typename AK0MK1GridDesc,
typename BK0NK1GridDesc,
typename CMNGridDesc,
@@ -274,11 +274,11 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

const auto a_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_a_grid, a_k0_m_k1_grid_desc.GetElementSpaceSize());
const auto b_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize());
auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum_t::Global>(
p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize());

const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
@@ -312,7 +312,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// A matrix blockwise copy
auto a_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, MPerBlock, K1>,
ABlockTransferThreadSliceLengths_K0_M_K1,
ABlockTransferThreadClusterLengths_K0_M_K1,
@@ -339,7 +339,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
// B matrix blockwise copy
auto b_blockwise_copy =
BlockwiseDynamicTensorSliceTransfer_v4<BlockSize,
InMemoryDataOperation::Set,
InMemoryDataOperationEnum_t::Set,
Sequence<KPerBlock, NPerBlock, K1>,
BBlockTransferThreadSliceLengths_K0_N_K1,
BBlockTransferThreadClusterLengths_K0_N_K1,
@@ -413,7 +413,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto c_mr_nr_blk_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
make_tuple(Number<MRepeat>{}, Number<NRepeat>{}));

StaticBuffer<AddressSpace::Vgpr,
StaticBuffer<AddressSpaceEnum_t::Vgpr,
vector_type<FloatAcc, BlkSize>,
c_mr_nr_blk_desc.GetElementSpaceSize()>
c_thread_buf;
@@ -442,9 +442,9 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
constexpr auto b_k0_n_k1_grid_move_slice_window_iterator_hack =
BGridMoveSliceWindowIteratorHacks{};

auto a_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize());
auto b_block_buf = make_dynamic_buffer<AddressSpace::Lds>(
auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum_t::Lds>(
p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize());

// preload data into LDS
@@ -515,7 +515,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
Number<M2>{},
Number<1>{}));

StaticBuffer<AddressSpace::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, c_m0_m1_m2_n_thread_desc.GetElementSpaceSize()>
c_blk_buf_;

static_for<0, MRepeat, 1>{}([&](auto mr_i) {
@@ -585,7 +585,7 @@ struct GridwiseDynamicGemm_k0mk1_k0nk1_mn_xdlops_v2r3
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(
I1, I1, I1, I1, Number<M0>{}, Number<1>{}, Number<M2>{}, Number<1>{}));

StaticBuffer<AddressSpace::Vgpr, FloatC, BlkSize> c_blk_buf_;
StaticBuffer<AddressSpaceEnum_t::Vgpr, FloatC, BlkSize> c_blk_buf_;

// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index

@@ -1,5 +1,5 @@
#ifndef CK_THREADWISE_CONTRACTION_HPP
#define CK_THREADWISE_CONTRACTION_HPP
#ifndef CK_THREADWISE_CONTRACTION_DLOPS_HPP
#define CK_THREADWISE_CONTRACTION_DLOPS_HPP

#include "common_header.hpp"
#include "math.hpp"
@@ -25,9 +25,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{
__device__ constexpr ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1()
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
@@ -71,8 +71,6 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto TK = TKLengths{}[I0];
constexpr auto TM0 = TMLengths{}[I0];
@@ -131,9 +129,9 @@ template <typename FloatA,
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false>
struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{
__device__ constexpr ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
{
static_assert(AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
@@ -177,8 +175,6 @@ struct ThreadwiseContraction_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_T

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr index_t TK0 = TKLengths{}[I0];
constexpr index_t TK1 = TKLengths{}[I1];
@@ -54,7 +54,7 @@ template <typename SrcData,
typename DimAccessOrder,
index_t DstVectorDim,
index_t DstScalarPerVector,
InMemoryDataOperation DstInMemOp,
InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
@@ -159,9 +159,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep(I0) = true;
forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
@@ -170,10 +170,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});

forward_sweep(i) = tmp % 2 == 0;
forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep;
return forward_sweep_;
}();
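// [editor's note] What the parity test above computes, spelled out for a 2-d
// access space (illustration only): dimension i sweeps forward exactly when the
// combined index of all slower dimensions is even, giving a snake-like traversal
// with small coordinate steps between consecutive accesses.
//
// forward_sweep[0] = true;
// forward_sweep[1] = (idx[0] % 2 == 0); // nDim == 2: even rows go ->, odd rows <-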
// calculate dst data index
@@ -186,10 +186,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});

auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;

return dst_data_idx;
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
}();

typename vector_type_maker<DstData, DstScalarPerVector>::type dst_vector;
@@ -217,17 +215,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3

constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
StaticallyIndexedArray<bool, nDim> move_on_dim_;

static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});

return move_on_dim;
return move_on_dim_;
}
();
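// [editor's note] move_on_dim is an odometer rule, restated for clarity
// (illustration only): dimension i advances exactly when it is not at its last
// position and every faster dimension j > i has just wrapped around.
//
// move_on_dim[i] = (idx[i] < len[i] - 1);
// for(index_t j = i + 1; j < nDim; ++j)
//     move_on_dim[i] &= (idx[j] == len[j] - 1);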
@@ -295,9 +293,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3

// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep(I0) = true;
forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
@@ -306,10 +304,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});

forward_sweep(i) = tmp % 2 == 0;
forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep;
return forward_sweep_;
}();

// calculate dst data index after last iteration in Run(), if it has not been reset by
@@ -321,19 +319,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});

auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;

return dst_data_idx;
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
dst_scalar_per_access;
}();

//
constexpr auto reset_dst_data_step = [&]() {
Index reset_dst_data_step;
Index reset_dst_data_step_;

static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; });
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

return reset_dst_data_step;
return reset_dst_data_step_;
}();

return reset_dst_data_step;
@@ -478,9 +474,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep(I0) = true;
forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_idx[I0];
@@ -489,10 +485,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
tmp = tmp * ordered_access_lengths[j] + ordered_access_idx[j];
});

forward_sweep(i) = tmp % 2 == 0;
forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep;
return forward_sweep_;
}();

// calculate src data index
@@ -505,10 +501,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
: ordered_access_lengths[i] - 1 - ordered_access_idx[i];
});

auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;

return src_data_idx;
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
}();

typename vector_type_maker<SrcData, SrcScalarPerVector>::type src_vector;
@@ -534,17 +528,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2

constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
StaticallyIndexedArray<bool, nDim> move_on_dim_;

static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;
move_on_dim_(i) = ordered_access_idx[i] < ordered_access_lengths[i] - 1;

static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
move_on_dim_(i) &= ordered_access_idx[j] == ordered_access_lengths[j] - 1;
});
});

return move_on_dim;
return move_on_dim_;
}
();

@@ -612,9 +606,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2

// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep(I0) = true;
forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_access_lengths[I0] - 1;
@@ -623,10 +617,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
tmp = tmp * ordered_access_lengths[j] + ordered_access_lengths[j] - 1;
});

forward_sweep(i) = tmp % 2 == 0;
forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep;
return forward_sweep_;
}();

// calculate src data index after last iteration in Run(), if it has not been reset by
@@ -638,19 +632,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
ordered_idx(i) = forward_sweep[i] ? ordered_access_lengths[i] - 1 : 0;
});

auto src_data_idx = container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;

return src_data_idx;
return container_reorder_given_old2new(ordered_idx, dim_access_order) *
src_scalar_per_access;
}();

//
constexpr auto reset_src_data_step = [&]() {
Index reset_src_data_step;
Index reset_src_data_step_;

static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });

return reset_src_data_step;
return reset_src_data_step_;
}();

return reset_src_data_step;
@@ -682,7 +674,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v2
// 3. src_slice_origin and dst_slice_origin are not known at compile-time,
// 4. Use thread buffer
template <typename SliceLengths,
InMemoryDataOperation DstInMemOp,
InMemoryDataOperationEnum_t DstInMemOp,
typename SrcData,
typename DstData,
typename SrcDesc,
@@ -739,8 +731,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
const SrcBuffer& src_buf,
const SrcIteratorHacks& src_iterator_hacks)
{
static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!");

static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
@@ -797,9 +789,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep(I0) = true;
forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_src_access_idx[I0];
@@ -808,10 +800,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
});

forward_sweep(i) = tmp % 2 == 0;
forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep;
return forward_sweep_;
}();

// calculate src data index
@@ -824,11 +816,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
ordered_src_access_idx[i];
});

auto src_data_idx =
container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;

return src_data_idx;
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
src_scalar_per_access;
}();

vector_type_maker_t<SrcData, SrcScalarPerVector> src_tmp_vector;
@@ -852,18 +841,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3

constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
StaticallyIndexedArray<bool, nDim> move_on_dim_;

static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;

static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &=
move_on_dim_(i) &=
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
});
});

return move_on_dim;
return move_on_dim_;
}
();

@@ -900,8 +889,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
DstBuffer& dst_buf,
const DstIteratorHacks& dst_iterator_hacks)
{
static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
DstBuffer::GetAddressSpace() == AddressSpace::Lds,
static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
"wrong!");

static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
@@ -962,9 +951,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
// judge move forward or move backward
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
StaticallyIndexedArray<bool, nDim> forward_sweep_;

forward_sweep(I0) = true;
forward_sweep_(I0) = true;

static_for<1, nDim, 1>{}([&](auto i) {
index_t tmp = ordered_dst_access_idx[I0];
@@ -973,10 +962,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
});

forward_sweep(i) = tmp % 2 == 0;
forward_sweep_(i) = tmp % 2 == 0;
});

return forward_sweep;
return forward_sweep_;
}();

// calculate dst data index
@@ -989,11 +978,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
ordered_dst_access_idx[i];
});

auto dst_data_idx =
container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;

return dst_data_idx;
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
dst_scalar_per_access;
}();

vector_type_maker_t<DstData, DstScalarPerVector> dst_tmp_vector;
@@ -1019,18 +1005,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3

constexpr auto move_on_dim = [&]() constexpr
{
StaticallyIndexedArray<bool, nDim> move_on_dim;
StaticallyIndexedArray<bool, nDim> move_on_dim_;

static_for<0, nDim, 1>{}([&](auto i) {
move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;

static_for<i + 1, nDim, 1>{}([&](auto j) {
move_on_dim(i) &=
move_on_dim_(i) &=
ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
});
});

return move_on_dim;
return move_on_dim_;
}
();

@@ -1108,9 +1094,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3

// judge move forward or move backward during the last iteration
constexpr auto forward_sweep = [&]() {
StaticallyIndexedArray<bool, nDim> forward_sweep;
StaticallyIndexedArray<bool, nDim> forward_sweep_;
||||
forward_sweep(I0) = true;
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_lengths[I0] - 1;
|
||||
@@ -1119,10 +1105,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep(i) = tmp % 2 == 0;
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep;
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate src data index after last iteration in RunRead(), if it has not being reset by
|
||||
@@ -1134,19 +1120,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
|
||||
return src_data_idx;
|
||||
return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_src_data_step = [&]() {
|
||||
Index reset_src_data_step;
|
||||
Index reset_src_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });
|
||||
|
||||
return reset_src_data_step;
|
||||
return reset_src_data_step_;
|
||||
}();
|
||||
|
||||
return reset_src_data_step;
|
||||
@@ -1170,9 +1154,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
|
||||
// judge move forward or move backward during the last iteration
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep;
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep_;
|
||||
|
||||
forward_sweep(I0) = true;
|
||||
forward_sweep_(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_dst_access_lengths[I0] - 1;
|
||||
@@ -1181,10 +1165,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
|
||||
});
|
||||
|
||||
forward_sweep(i) = tmp % 2 == 0;
|
||||
forward_sweep_(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep;
|
||||
return forward_sweep_;
|
||||
}();
|
||||
|
||||
// calculate dst data index after last iteration in RunWrite(), if it has not being reset by
|
||||
@@ -1196,19 +1180,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
|
||||
});
|
||||
|
||||
auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
|
||||
return dst_data_idx;
|
||||
return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
|
||||
dst_scalar_per_access;
|
||||
}();
|
||||
|
||||
//
|
||||
constexpr auto reset_dst_data_step = [&]() {
|
||||
Index reset_dst_data_step;
|
||||
Index reset_dst_data_step_;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; });
|
||||
static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });
|
||||
|
||||
return reset_dst_data_step;
|
||||
return reset_dst_data_step_;
|
||||
}();
|
||||
|
||||
return reset_dst_data_step;
|
||||
@@ -1270,7 +1252,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3
|
||||
|
||||
static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();
|
||||
|
||||
StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;
|
||||
StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;
|
||||
|
||||
SrcCoord src_coord_;
|
||||
DstCoord dst_coord_;
|
||||
@@ -1357,9 +1339,6 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
|
||||
constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
|
||||
constexpr auto dst_origin_idx = to_multi_index(DstOriginIdx{});
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
// scalar per access of each dim
|
||||
constexpr auto src_scalar_per_access = generate_sequence_v2(
|
||||
[&](auto i) constexpr {
|
||||
|
||||
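Note on the forward_sweep/move_on_dim hunks above: the change is purely a rename (forward_sweep_, move_on_dim_, reset_src_data_step_) so the lambda's local no longer shares a name with the constexpr variable it initializes. The underlying parity rule is unchanged: dimension i sweeps forward exactly when the combined index of all slower dimensions is even, which yields a snake-order walk over the access grid. A minimal standalone C++ sketch of that rule (illustration only, not CK code):

// Standalone illustration of the sweep-direction rule.
#include <cstdio>

int main()
{
    constexpr int L0 = 2, L1 = 3; // hypothetical access lengths
    for(int i0 = 0; i0 < L0; ++i0)
    {
        // dim 1 sweeps forward iff the combined index of slower dims is even
        const bool forward = (i0 % 2 == 0);
        for(int s = 0; s < L1; ++s)
        {
            const int i1 = forward ? s : L1 - 1 - s;
            std::printf("(%d, %d)\n", i0, i1); // consecutive visits are adjacent
        }
    }
    return 0;
}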
@@ -13,7 +13,7 @@ namespace ck {
 // 3. src_slice_origin and dst_slice_origin are not known at compile-time,
 // 4. Use thread buffer
 template <typename SliceLengths,
-          InMemoryDataOperation DstInMemOp,
+          InMemoryDataOperationEnum_t DstInMemOp,
           typename SrcData,
           typename DstData,
           typename SrcDesc,
@@ -77,8 +77,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                  const SrcBuffer& src_buf,
                  const SrcIteratorHacks& src_iterator_hacks)
     {
-        static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
-                          SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
+        static_assert(SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
+                          SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                       "wrong!");

         static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
@@ -140,9 +140,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
         static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
             // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
-                StaticallyIndexedArray<bool, nDim> forward_sweep;
+                StaticallyIndexedArray<bool, nDim> forward_sweep_;

-                forward_sweep(I0) = true;
+                forward_sweep_(I0) = true;

                 static_for<1, nDim, 1>{}([&](auto i) {
                     index_t tmp = ordered_src_access_idx[I0];
@@ -151,10 +151,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                         tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
                     });

-                    forward_sweep(i) = tmp % 2 == 0;
+                    forward_sweep_(i) = tmp % 2 == 0;
                 });

-                return forward_sweep;
+                return forward_sweep_;
             }();

             // calculate src data index
@@ -167,11 +167,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                         ordered_src_access_idx[i];
                 });

-                auto src_data_idx =
-                    container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
-                    src_vector_tensor_lengths;
-
-                return src_data_idx;
+                return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
+                       src_vector_tensor_lengths;
             }();

             vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
@@ -201,18 +198,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

             constexpr auto move_on_dim = [&]() constexpr
             {
-                StaticallyIndexedArray<bool, nDim> move_on_dim;
+                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
+                    move_on_dim_(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;

                     static_for<i + 1, nDim, 1>{}([&](auto j) {
-                        move_on_dim(i) &=
+                        move_on_dim_(i) &=
                             ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
                     });
                 });

-                return move_on_dim;
+                return move_on_dim_;
             }
             ();

@@ -249,8 +246,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                   DstBuffer& dst_buf,
                   const DstIteratorHacks& dst_iterator_hacks)
     {
-        static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
-                          DstBuffer::GetAddressSpace() == AddressSpace::Lds,
+        static_assert(DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Global or
+                          DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds,
                       "wrong!");

         static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
@@ -316,9 +313,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
         static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
             // judge move forward or move backward
             constexpr auto forward_sweep = [&]() {
-                StaticallyIndexedArray<bool, nDim> forward_sweep;
+                StaticallyIndexedArray<bool, nDim> forward_sweep_;

-                forward_sweep(I0) = true;
+                forward_sweep_(I0) = true;

                 static_for<1, nDim, 1>{}([&](auto i) {
                     index_t tmp = ordered_dst_access_idx[I0];
@@ -327,10 +324,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                         tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
                     });

-                    forward_sweep(i) = tmp % 2 == 0;
+                    forward_sweep_(i) = tmp % 2 == 0;
                 });

-                return forward_sweep;
+                return forward_sweep_;
             }();

             // calculate dst data index
@@ -343,11 +340,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                         ordered_dst_access_idx[i];
                 });

-                auto dst_data_idx =
-                    container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
-                    dst_vector_tensor_lengths;
-
-                return dst_data_idx;
+                return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
+                       dst_vector_tensor_lengths;
             }();

             vector_type_maker_t<DstData, dst_vector_desc.GetElementSpaceSize()> dst_vector;
@@ -379,18 +373,18 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

             constexpr auto move_on_dim = [&]() constexpr
             {
-                StaticallyIndexedArray<bool, nDim> move_on_dim;
+                StaticallyIndexedArray<bool, nDim> move_on_dim_;

                 static_for<0, nDim, 1>{}([&](auto i) {
-                    move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;
+                    move_on_dim_(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;

                     static_for<i + 1, nDim, 1>{}([&](auto j) {
-                        move_on_dim(i) &=
+                        move_on_dim_(i) &=
                             ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
                     });
                 });

-                return move_on_dim;
+                return move_on_dim_;
             }
             ();

@@ -463,9 +457,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

         // judge move forward or move backward during the last iteration
         constexpr auto forward_sweep = [&]() {
-            StaticallyIndexedArray<bool, nDim> forward_sweep;
+            StaticallyIndexedArray<bool, nDim> forward_sweep_;

-            forward_sweep(I0) = true;
+            forward_sweep_(I0) = true;

             static_for<1, nDim, 1>{}([&](auto i) {
                 index_t tmp = ordered_src_access_lengths[I0] - 1;
@@ -474,10 +468,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                     tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
                 });

-                forward_sweep(i) = tmp % 2 == 0;
+                forward_sweep_(i) = tmp % 2 == 0;
             });

-            return forward_sweep;
+            return forward_sweep_;
         }();

         // calculate src data index after last iteration in RunRead(), if it has not being reset by
@@ -489,19 +483,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
             ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
         });

-        auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
-                            src_vector_tensor_lengths;
-
-        return src_data_idx;
+        return container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
+               src_vector_tensor_lengths;
     }();

     //
     constexpr auto reset_src_data_step = [&]() {
-        Index reset_src_data_step;
+        Index reset_src_data_step_;

-        static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });
+        static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step_(i) = -src_data_idx[i]; });

-        return reset_src_data_step;
+        return reset_src_data_step_;
     }();

     return reset_src_data_step;
@@ -520,9 +512,9 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

         // judge move forward or move backward during the last iteration
         constexpr auto forward_sweep = [&]() {
-            StaticallyIndexedArray<bool, nDim> forward_sweep;
+            StaticallyIndexedArray<bool, nDim> forward_sweep_;

-            forward_sweep(I0) = true;
+            forward_sweep_(I0) = true;

             static_for<1, nDim, 1>{}([&](auto i) {
                 index_t tmp = ordered_dst_access_lengths[I0] - 1;
@@ -531,10 +523,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
                     tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
                 });

-                forward_sweep(i) = tmp % 2 == 0;
+                forward_sweep_(i) = tmp % 2 == 0;
             });

-            return forward_sweep;
+            return forward_sweep_;
         }();

         // calculate dst data index after last iteration in RunWrite(), if it has not being reset by
@@ -546,19 +538,17 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1
             ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
         });

-        auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
-                            dst_vector_tensor_lengths;
-
-        return dst_data_idx;
+        return container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
+               dst_vector_tensor_lengths;
     }();

     //
     constexpr auto reset_dst_data_step = [&]() {
-        Index reset_dst_data_step;
+        Index reset_dst_data_step_;

-        static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; });
+        static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step_(i) = -dst_data_idx[i]; });

-        return reset_dst_data_step;
+        return reset_dst_data_step_;
     }();

     return reset_dst_data_step;
@@ -620,7 +610,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v3r1

     static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

-    StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;
+    StaticBuffer<AddressSpaceEnum_t::Vgpr, SrcData, buffer_size_> buffer_;

     SrcCoord src_coord_;
     DstCoord dst_coord_;
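The GetSrcCoordinateResetStep()/GetDstCoordinateResetStep() hunks above all share one idea: after the last access of RunRead()/RunWrite(), the running coordinate sits at the final data index, so negating that index gives the step that returns the iterator to the slice origin. A small plain-C++ sketch of the assumed semantics (not CK code):

// Plain C++ sketch, assuming N-dimensional integer indices.
#include <array>
#include <cstddef>

template <std::size_t N>
constexpr std::array<int, N> reset_step(const std::array<int, N>& final_idx)
{
    std::array<int, N> step{};
    for(std::size_t i = 0; i < N; ++i)
        step[i] = -final_idx[i]; // mirrors reset_src_data_step_(i) = -src_data_idx[i]
    return step;
}

// moving by reset_step({3, 4}) == {-3, -4} lands back on the slice origin
static_assert(reset_step<2>({3, 4})[0] == -3 && reset_step<2>({3, 4})[1] == -4, "");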
@@ -1,5 +1,5 @@
-#ifndef CK_THREADWISE_GEMM_V3_HPP
-#define CK_THREADWISE_GEMM_V3_HPP
+#ifndef CK_THREADWISE_GEMM_DLOPS_V3_HPP
+#define CK_THREADWISE_GEMM_DLOPS_V3_HPP

 #include "common_header.hpp"
 #include "math.hpp"
@@ -22,7 +22,7 @@ template <typename FloatA,
           typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                       CDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
-struct ThreadwiseGemm_km_kn_mn_v3
+struct ThreadwiseGemmDlops_km_kn_mn_v3
 {
     template <typename ABuffer,
               typename AOriginIdx,
@@ -1,7 +1,7 @@
 #ifndef CK_AMD_BUFFER_ADDRESSING_V2_HPP
 #define CK_AMD_BUFFER_ADDRESSING_V2_HPP

-#include "float_type.hpp"
+#include "data_type.hpp"

 namespace ck {

@@ -33,175 +33,175 @@ __device__ int32x4_t make_wave_buffer_resource(T* p_wave, index_t data_space_size

 // load
 __device__ int8_t
-__llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
-                                 index_t voffset,
-                                 index_t soffset,
-                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");
+llvm_amdgcn_raw_buffer_load_i8(int32x4_t srsrc,
+                               index_t voffset,
+                               index_t soffset,
+                               index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i8");

 __device__ int8x2_t
-__llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
-                                   index_t voffset,
-                                   index_t soffset,
-                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");
+llvm_amdgcn_raw_buffer_load_i8x2(int32x4_t srsrc,
+                                 index_t voffset,
+                                 index_t soffset,
+                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i8");

 __device__ int8x4_t
-__llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
-                                   index_t voffset,
-                                   index_t soffset,
-                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");
+llvm_amdgcn_raw_buffer_load_i8x4(int32x4_t srsrc,
+                                 index_t voffset,
+                                 index_t soffset,
+                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i8");

 __device__ int16_t
-__llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
-                                  index_t voffset,
-                                  index_t soffset,
-                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
+llvm_amdgcn_raw_buffer_load_i16(int32x4_t srsrc,
+                                index_t voffset,
+                                index_t soffset,
+                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
 __device__ int32_t
-__llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
-                                  index_t voffset,
-                                  index_t soffset,
-                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");
+llvm_amdgcn_raw_buffer_load_i32(int32x4_t srsrc,
+                                index_t voffset,
+                                index_t soffset,
+                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.i32");

 __device__ int32x2_t
-__llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
-                                    index_t voffset,
-                                    index_t soffset,
-                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");
+llvm_amdgcn_raw_buffer_load_i32x2(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2i32");

 __device__ int32x4_t
-__llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
-                                    index_t voffset,
-                                    index_t soffset,
-                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
+llvm_amdgcn_raw_buffer_load_i32x4(int32x4_t srsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
 // half
 __device__ half_t
-__llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
-                                   index_t voffset,
-                                   index_t soffset,
-                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");
+llvm_amdgcn_raw_buffer_load_fp16(int32x4_t srsrc,
+                                 index_t voffset,
+                                 index_t soffset,
+                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f16");

 __device__ half2_t
-__llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
-                                     index_t voffset,
-                                     index_t soffset,
-                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");
+llvm_amdgcn_raw_buffer_load_fp16x2(int32x4_t srsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f16");

 __device__ half4_t
-__llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
-                                     index_t voffset,
-                                     index_t soffset,
-                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");
+llvm_amdgcn_raw_buffer_load_fp16x4(int32x4_t srsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f16");

 // float
 __device__ float
-__llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
-                                   index_t voffset,
-                                   index_t soffset,
-                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");
+llvm_amdgcn_raw_buffer_load_fp32(int32x4_t srsrc,
+                                 index_t voffset,
+                                 index_t soffset,
+                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.f32");

 __device__ float2_t
-__llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
-                                     index_t voffset,
-                                     index_t soffset,
-                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");
+llvm_amdgcn_raw_buffer_load_fp32x2(int32x4_t srsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v2f32");

 __device__ float4_t
-__llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
-                                     index_t voffset,
-                                     index_t soffset,
-                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");
+llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");

 // store
 __device__ void
-__llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
-                                  int32x4_t rsrc,
-                                  index_t voffset,
-                                  index_t soffset,
-                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");
+llvm_amdgcn_raw_buffer_store_i8(int8_t vdata,
+                                int32x4_t rsrc,
+                                index_t voffset,
+                                index_t soffset,
+                                index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i8");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
-                                    int32x4_t rsrc,
-                                    index_t voffset,
-                                    index_t soffset,
-                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");
+llvm_amdgcn_raw_buffer_store_i8x2(int8x2_t vdata,
+                                  int32x4_t rsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i8");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
-                                    int32x4_t rsrc,
-                                    index_t voffset,
-                                    index_t soffset,
-                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");
+llvm_amdgcn_raw_buffer_store_i8x4(int8x4_t vdata,
+                                  int32x4_t rsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i8");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
-                                   int32x4_t rsrc,
-                                   index_t voffset,
-                                   index_t soffset,
-                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");
+llvm_amdgcn_raw_buffer_store_i16(int16_t vdata,
+                                 int32x4_t rsrc,
+                                 index_t voffset,
+                                 index_t soffset,
+                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i16");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
-                                   int32x4_t rsrc,
-                                   index_t voffset,
-                                   index_t soffset,
-                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");
+llvm_amdgcn_raw_buffer_store_i32(int32_t vdata,
+                                 int32x4_t rsrc,
+                                 index_t voffset,
+                                 index_t soffset,
+                                 index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.i32");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
-                                     int32x4_t rsrc,
-                                     index_t voffset,
-                                     index_t soffset,
-                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");
+llvm_amdgcn_raw_buffer_store_i32x2(int32x2_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2i32");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
-                                     int32x4_t rsrc,
-                                     index_t voffset,
-                                     index_t soffset,
-                                     index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
+llvm_amdgcn_raw_buffer_store_i32x4(int32x4_t vdata,
+                                   int32x4_t rsrc,
+                                   index_t voffset,
+                                   index_t soffset,
+                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4i32");

 // half
 __device__ void
-__llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
-                                    int32x4_t rsrc,
-                                    index_t voffset,
-                                    index_t soffset,
-                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");
+llvm_amdgcn_raw_buffer_store_fp16(half_t vdata,
+                                  int32x4_t rsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f16");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
-                                      int32x4_t rsrc,
-                                      index_t voffset,
-                                      index_t soffset,
-                                      index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");
+llvm_amdgcn_raw_buffer_store_fp16x2(half2_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f16");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
-                                      int32x4_t rsrc,
-                                      index_t voffset,
-                                      index_t soffset,
-                                      index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
+llvm_amdgcn_raw_buffer_store_fp16x4(half4_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f16");
 // float
 __device__ void
-__llvm_amdgcn_raw_buffer_store_fp32(float vdata,
-                                    int32x4_t rsrc,
-                                    index_t voffset,
-                                    index_t soffset,
-                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");
+llvm_amdgcn_raw_buffer_store_fp32(float vdata,
+                                  int32x4_t rsrc,
+                                  index_t voffset,
+                                  index_t soffset,
+                                  index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.f32");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
-                                      int32x4_t rsrc,
-                                      index_t voffset,
-                                      index_t soffset,
-                                      index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");
+llvm_amdgcn_raw_buffer_store_fp32x2(float2_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v2f32");

 __device__ void
-__llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
-                                      int32x4_t rsrc,
-                                      index_t voffset,
-                                      index_t soffset,
-                                      index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");
+llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata,
+                                    int32x4_t rsrc,
+                                    index_t voffset,
+                                    index_t soffset,
+                                    index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32");

 template <typename T, index_t N>
 __device__ typename vector_type<T, N>::type
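The rename running through this file drops the double-underscore prefix from the intrinsic wrappers; identifiers beginning with two underscores are reserved for the implementation. The declaration pattern itself is unchanged: a device function declaration is bound to an LLVM intrinsic by name via __asm(...). A reduced sketch of that pattern (HIP/Clang device code; types approximated with ext_vector_type, as the real header does):

// Reduced sketch of the intrinsic-binding pattern.
using float4_t  = float __attribute__((ext_vector_type(4)));
using int32x4_t = int __attribute__((ext_vector_type(4)));
using index_t   = int; // stands in for ck::index_t

// The C++ name is free-form; __asm pins the symbol to the LLVM intrinsic.
// After this commit the name carries no leading "__".
__device__ float4_t
llvm_amdgcn_raw_buffer_load_fp32x4(int32x4_t srsrc,
                                   index_t voffset,
                                   index_t soffset,
                                   index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.load.v4f32");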
@@ -220,31 +220,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 {
     if constexpr(N == 1)
     {
-        return __llvm_amdgcn_raw_buffer_load_fp32(
+        return llvm_amdgcn_raw_buffer_load_fp32(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 2)
     {
-        return __llvm_amdgcn_raw_buffer_load_fp32x2(
+        return llvm_amdgcn_raw_buffer_load_fp32x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 4)
     {
-        return __llvm_amdgcn_raw_buffer_load_fp32x4(
+        return llvm_amdgcn_raw_buffer_load_fp32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 8)
     {
         vector_type<float, 8> tmp;

-        tmp.AsType<float4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp32x4(
+        tmp.AsType<float4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         tmp.AsType<float4_t>()(Number<1>{}) =
-            __llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
-                                                 src_thread_addr_offset,
-                                                 src_wave_addr_offset + 4 * sizeof(float),
-                                                 0);
+            llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource,
+                                               src_thread_addr_offset,
+                                               src_wave_addr_offset + 4 * sizeof(float),
+                                               0);

         return tmp.AsType<float8_t>()(Number<0>{});
     }
@@ -253,17 +253,17 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 {
     if constexpr(N == 1)
     {
-        return __llvm_amdgcn_raw_buffer_load_fp16(
+        return llvm_amdgcn_raw_buffer_load_fp16(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 2)
     {
-        return __llvm_amdgcn_raw_buffer_load_fp16x2(
+        return llvm_amdgcn_raw_buffer_load_fp16x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 4)
     {
-        return __llvm_amdgcn_raw_buffer_load_fp16x4(
+        return llvm_amdgcn_raw_buffer_load_fp16x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 8)
@@ -271,18 +271,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 #if 0
         vector_type<half_t, 8> tmp;

-        tmp.AsType<half4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_fp16x4(
+        tmp.AsType<half4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         tmp.AsType<half4_t>()(Number<1>{}) =
-            __llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
+            llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource,
                                                src_thread_addr_offset,
                                                src_wave_addr_offset + 4 * sizeof(half_t),
                                                0);

         return tmp.AsType<half8_t>()(Number<0>{});
 #else
-        float4_t tmp = __llvm_amdgcn_raw_buffer_load_fp32x4(
+        float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         return as_type<half8_t>(tmp);
@@ -293,31 +293,31 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 {
     if constexpr(N == 1)
     {
-        return __llvm_amdgcn_raw_buffer_load_i32(
+        return llvm_amdgcn_raw_buffer_load_i32(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 2)
     {
-        return __llvm_amdgcn_raw_buffer_load_i32x2(
+        return llvm_amdgcn_raw_buffer_load_i32x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 4)
     {
-        return __llvm_amdgcn_raw_buffer_load_i32x4(
+        return llvm_amdgcn_raw_buffer_load_i32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 8)
     {
         vector_type<int32_t, 8> tmp;

-        tmp.AsType<int32x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i32x4(
+        tmp.AsType<int32x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         tmp.AsType<int32x4_t>()(Number<1>{}) =
-            __llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
-                                                src_thread_addr_offset,
-                                                src_wave_addr_offset + 4 * sizeof(int32_t),
-                                                0);
+            llvm_amdgcn_raw_buffer_load_i32x4(src_wave_buffer_resource,
+                                              src_thread_addr_offset,
+                                              src_wave_addr_offset + 4 * sizeof(int32_t),
+                                              0);
         return tmp.AsType<int32x8_t>()(Number<0>{});
     }
 }
@@ -325,16 +325,16 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 {
     if constexpr(N == 1)
     {
-        return __llvm_amdgcn_raw_buffer_load_i8(
+        return llvm_amdgcn_raw_buffer_load_i8(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
     }
     else if constexpr(N == 2)
     {
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
-        return __llvm_amdgcn_raw_buffer_load_i8x2(
+        return llvm_amdgcn_raw_buffer_load_i8x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
-        int16_t tmp = __llvm_amdgcn_raw_buffer_load_i16(
+        int16_t tmp = llvm_amdgcn_raw_buffer_load_i16(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         return as_type<int8x2_t>(tmp);
@@ -343,10 +343,10 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
     else if constexpr(N == 4)
     {
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
-        return __llvm_amdgcn_raw_buffer_load_i8x4(
+        return llvm_amdgcn_raw_buffer_load_i8x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);
 #else
-        int32_t tmp = __llvm_amdgcn_raw_buffer_load_i32(
+        int32_t tmp = llvm_amdgcn_raw_buffer_load_i32(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         return as_type<int8x4_t>(tmp);
@@ -357,18 +357,18 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         vector_type<int8_t, 8> tmp;

-        tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
+        tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         tmp.AsType<int8x4_t>()(Number<1>{}) =
-            __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
-                                               src_thread_addr_offset,
-                                               src_wave_addr_offset + 4 * sizeof(int8_t),
-                                               0);
+            llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                             src_thread_addr_offset,
+                                             src_wave_addr_offset + 4 * sizeof(int8_t),
+                                             0);

         return tmp.AsType<int8x8_t>()(Number<0>{});
 #else
-        int32x2_t tmp = __llvm_amdgcn_raw_buffer_load_i32x2(
+        int32x2_t tmp = llvm_amdgcn_raw_buffer_load_i32x2(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         return as_type<int8x8_t>(tmp);
@@ -379,30 +379,30 @@ amd_buffer_load_impl_v2(int32x4_t src_wave_buffer_resource,
 #if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
         vector_type<int8_t, 16> tmp;

-        tmp.AsType<int8x4_t>()(Number<0>{}) = __llvm_amdgcn_raw_buffer_load_i8x4(
+        tmp.AsType<int8x4_t>()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_i8x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         tmp.AsType<int8x4_t>()(Number<1>{}) =
-            __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
-                                               src_thread_addr_offset,
-                                               src_wave_addr_offset + 4 * sizeof(int8_t),
-                                               0);
+            llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                             src_thread_addr_offset,
+                                             src_wave_addr_offset + 4 * sizeof(int8_t),
+                                             0);

         tmp.AsType<int8x4_t>()(Number<2>{}) =
-            __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
-                                               src_thread_addr_offset,
-                                               src_wave_addr_offset + 8 * sizeof(int8_t),
-                                               0);
+            llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                             src_thread_addr_offset,
+                                             src_wave_addr_offset + 8 * sizeof(int8_t),
+                                             0);

         tmp.AsType<int8x4_t>()(Number<3>{}) =
-            __llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
-                                               src_thread_addr_offset,
-                                               src_wave_addr_offset + 12 * sizeof(int8_t),
-                                               0);
+            llvm_amdgcn_raw_buffer_load_i8x4(src_wave_buffer_resource,
+                                             src_thread_addr_offset,
+                                             src_wave_addr_offset + 12 * sizeof(int8_t),
+                                             0);

         return tmp.AsType<int8x16_t>()(Number<0>{});
 #else
-        int32x4_t tmp = __llvm_amdgcn_raw_buffer_load_i32x4(
+        int32x4_t tmp = llvm_amdgcn_raw_buffer_load_i32x4(
             src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0);

         return as_type<int8x16_t>(tmp);
@@ -428,61 +428,7 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
 {
     if constexpr(N == 1)
     {
-        __llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
-                                            dst_wave_buffer_resource,
-                                            dst_thread_addr_offset,
-                                            dst_wave_addr_offset,
-                                            0);
-    }
-    else if constexpr(N == 2)
-    {
-        __llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
-                                              dst_wave_buffer_resource,
-                                              dst_thread_addr_offset,
-                                              dst_wave_addr_offset,
-                                              0);
-    }
-    else if constexpr(N == 4)
-    {
-        __llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
-                                              dst_wave_buffer_resource,
-                                              dst_thread_addr_offset,
-                                              dst_wave_addr_offset,
-                                              0);
-    }
-}
-else if constexpr(is_same<T, int32_t>::value)
-{
-    if constexpr(N == 1)
-    {
-        __llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
-                                           dst_wave_buffer_resource,
-                                           dst_thread_addr_offset,
-                                           dst_wave_addr_offset,
-                                           0);
-    }
-    else if constexpr(N == 2)
-    {
-        __llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
-                                             dst_wave_buffer_resource,
-                                             dst_thread_addr_offset,
-                                             dst_wave_addr_offset,
-                                             0);
-    }
-    else if constexpr(N == 4)
-    {
-        __llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
-                                             dst_wave_buffer_resource,
-                                             dst_thread_addr_offset,
-                                             dst_wave_addr_offset,
-                                             0);
-    }
-}
-else if constexpr(is_same<T, int8_t>::value)
-{
-    if constexpr(N == 1)
-    {
-        __llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
+        llvm_amdgcn_raw_buffer_store_fp32(src_thread_data,
                                           dst_wave_buffer_resource,
                                           dst_thread_addr_offset,
                                           dst_wave_addr_offset,
@@ -490,94 +436,148 @@ __device__ void amd_buffer_store_impl_v2(const typename vector_type<T, N>::type
     }
     else if constexpr(N == 2)
     {
-#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
-        __llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
-                                            dst_wave_buffer_resource,
-                                            dst_thread_addr_offset,
-                                            dst_wave_addr_offset,
-                                            0);
-#else
-        __llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
-                                           dst_wave_buffer_resource,
-                                           dst_thread_addr_offset,
-                                           dst_wave_addr_offset,
-                                           0);
-#endif
-    }
-    else if constexpr(N == 4)
-    {
-#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
-        __llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
-                                            dst_wave_buffer_resource,
-                                            dst_thread_addr_offset,
-                                            dst_wave_addr_offset,
-                                            0);
-#else
-        __llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
-                                           dst_wave_buffer_resource,
-                                           dst_thread_addr_offset,
-                                           dst_wave_addr_offset,
-                                           0);
-#endif
-    }
-    else if constexpr(N == 8)
-    {
-        __llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
-                                             dst_wave_buffer_resource,
-                                             dst_thread_addr_offset,
-                                             dst_wave_addr_offset,
-                                             0);
-    }
-    else if constexpr(N == 16)
-    {
-        __llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
-                                             dst_wave_buffer_resource,
-                                             dst_thread_addr_offset,
-                                             dst_wave_addr_offset,
-                                             0);
-    }
-}
-else if constexpr(is_same<T, half_t>::value)
-{
-    if constexpr(N == 1)
-    {
-        __llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
-                                            dst_wave_buffer_resource,
-                                            dst_thread_addr_offset,
-                                            dst_wave_addr_offset,
-                                            0);
-    }
-    else if constexpr(N == 2)
-    {
-        __llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
-                                              dst_wave_buffer_resource,
-                                              dst_thread_addr_offset,
-                                              dst_wave_addr_offset,
-                                              0);
-    }
-    else if constexpr(N == 4)
-    {
-        __llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
-                                              dst_wave_buffer_resource,
-                                              dst_thread_addr_offset,
-                                              dst_wave_addr_offset,
-                                              0);
-    }
-    else if constexpr(N == 8)
-    {
-        vector_type<half_t, 8> tmp{src_thread_data};
-
-        __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
-                                              dst_wave_buffer_resource,
-                                              dst_thread_addr_offset,
-                                              dst_wave_addr_offset,
-                                              0);
-
-        __llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
-                                              dst_wave_buffer_resource,
-                                              dst_thread_addr_offset,
-                                              dst_wave_addr_offset + 4 * sizeof(half_t),
-                                              0);
-    }
-}
+        llvm_amdgcn_raw_buffer_store_fp32x2(src_thread_data,
+                                            dst_wave_buffer_resource,
+                                            dst_thread_addr_offset,
+                                            dst_wave_addr_offset,
+                                            0);
+    }
+    else if constexpr(N == 4)
+    {
+        llvm_amdgcn_raw_buffer_store_fp32x4(src_thread_data,
+                                            dst_wave_buffer_resource,
+                                            dst_thread_addr_offset,
+                                            dst_wave_addr_offset,
+                                            0);
+    }
+}
+else if constexpr(is_same<T, int32_t>::value)
+{
+    if constexpr(N == 1)
+    {
+        llvm_amdgcn_raw_buffer_store_i32(src_thread_data,
+                                         dst_wave_buffer_resource,
+                                         dst_thread_addr_offset,
+                                         dst_wave_addr_offset,
+                                         0);
+    }
+    else if constexpr(N == 2)
+    {
+        llvm_amdgcn_raw_buffer_store_i32x2(src_thread_data,
+                                           dst_wave_buffer_resource,
+                                           dst_thread_addr_offset,
+                                           dst_wave_addr_offset,
+                                           0);
+    }
+    else if constexpr(N == 4)
+    {
+        llvm_amdgcn_raw_buffer_store_i32x4(src_thread_data,
+                                           dst_wave_buffer_resource,
+                                           dst_thread_addr_offset,
+                                           dst_wave_addr_offset,
+                                           0);
+    }
+}
+else if constexpr(is_same<T, int8_t>::value)
+{
+    if constexpr(N == 1)
+    {
+        llvm_amdgcn_raw_buffer_store_i8(src_thread_data,
+                                        dst_wave_buffer_resource,
+                                        dst_thread_addr_offset,
+                                        dst_wave_addr_offset,
+                                        0);
+    }
+    else if constexpr(N == 2)
+    {
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
+        llvm_amdgcn_raw_buffer_store_i8x2(src_thread_data,
+                                          dst_wave_buffer_resource,
+                                          dst_thread_addr_offset,
+                                          dst_wave_addr_offset,
+                                          0);
+#else
+        llvm_amdgcn_raw_buffer_store_i16(as_type<int16_t>(src_thread_data),
+                                         dst_wave_buffer_resource,
+                                         dst_thread_addr_offset,
+                                         dst_wave_addr_offset,
+                                         0);
+#endif
+    }
+    else if constexpr(N == 4)
+    {
+#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE
+        llvm_amdgcn_raw_buffer_store_i8x4(src_thread_data,
+                                          dst_wave_buffer_resource,
+                                          dst_thread_addr_offset,
+                                          dst_wave_addr_offset,
+                                          0);
+#else
+        llvm_amdgcn_raw_buffer_store_i32(as_type<int32_t>(src_thread_data),
+                                         dst_wave_buffer_resource,
+                                         dst_thread_addr_offset,
+                                         dst_wave_addr_offset,
+                                         0);
+#endif
+    }
+    else if constexpr(N == 8)
+    {
+        llvm_amdgcn_raw_buffer_store_i32x2(as_type<int32x2_t>(src_thread_data),
+                                           dst_wave_buffer_resource,
+                                           dst_thread_addr_offset,
+                                           dst_wave_addr_offset,
+                                           0);
+    }
+    else if constexpr(N == 16)
+    {
+        llvm_amdgcn_raw_buffer_store_i32x4(as_type<int32x4_t>(src_thread_data),
+                                           dst_wave_buffer_resource,
+                                           dst_thread_addr_offset,
+                                           dst_wave_addr_offset,
+                                           0);
+    }
+}
+else if constexpr(is_same<T, half_t>::value)
+{
+    if constexpr(N == 1)
+    {
+        llvm_amdgcn_raw_buffer_store_fp16(src_thread_data,
+                                          dst_wave_buffer_resource,
+                                          dst_thread_addr_offset,
+                                          dst_wave_addr_offset,
+                                          0);
+    }
+    else if constexpr(N == 2)
+    {
+        llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data,
+                                            dst_wave_buffer_resource,
+                                            dst_thread_addr_offset,
+                                            dst_wave_addr_offset,
+                                            0);
+    }
+    else if constexpr(N == 4)
+    {
+        llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data,
+                                            dst_wave_buffer_resource,
+                                            dst_thread_addr_offset,
+                                            dst_wave_addr_offset,
+                                            0);
+    }
+    else if constexpr(N == 8)
+    {
+        vector_type<half_t, 8> tmp{src_thread_data};
+
+        llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<0>{}],
+                                            dst_wave_buffer_resource,
+                                            dst_thread_addr_offset,
+                                            dst_wave_addr_offset,
+                                            0);
+
+        llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType<half4_t>()[Number<1>{}],
+                                            dst_wave_buffer_resource,
+                                            dst_thread_addr_offset,
+                                            dst_wave_addr_offset + 4 * sizeof(half_t),
+                                            0);
+    }
+}
 }
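The CK_WORKAROUND_SWDEV_XXXXXX_INT8_BUFFER_LOAD_STORE_ISSUE branches above all follow one pattern: when the native sub-dword int8 vector load/store misbehaves, move the same bytes through a wider integer type and bit-cast the result. A host-side C++ sketch of the idea, with a plain stand-in for CK's as_type helper (the names below are illustrative):

// Plain C++ sketch of the workaround pattern; as_type_sketch stands in for
// CK's as_type bit-cast. The byte movement is identical, only through i32.
#include <cstdint>
#include <cstring>

using int8x4_t = signed char __attribute__((ext_vector_type(4)));

template <typename Y, typename X>
Y as_type_sketch(X x) // bit-exact reinterpretation; sizes must match
{
    static_assert(sizeof(X) == sizeof(Y), "size mismatch");
    Y y;
    std::memcpy(&y, &x, sizeof(X));
    return y;
}

int8x4_t load_i8x4_via_i32(const std::int32_t* p)
{
    return as_type_sketch<int8x4_t>(p[0]); // one dword load instead of v4i8
}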
@@ -1,7 +1,7 @@
 #ifndef CK_AMD_DLOP_HPP
 #define CK_AMD_DLOP_HPP

-#include "float_type.hpp"
+#include "data_type.hpp"

 namespace ck {

@@ -1,7 +1,7 @@
 #ifndef CK_AMD_INLINE_ASM_HPP
 #define CK_AMD_INLINE_ASM_HPP

-#include "float_type.hpp"
+#include "data_type.hpp"

 namespace ck {

@@ -1,11 +1,11 @@
 #ifndef CK_AMD_LLVM_INTRINSIC_HPP
 #define CK_AMD_LLVM_INTRINSIC_HPP

-#include "float_type.hpp"
+#include "data_type.hpp"

 namespace ck {

-__device__ int32_t __llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");
+__device__ int32_t llvm_amdgcn_readfirstlane_i32(int32_t i) __asm("llvm.amdgcn.readfirstlane");

 } // namespace ck
 #endif

@@ -1,7 +1,7 @@
 #ifndef CK_AMD_XDLOPS_HPP
 #define CK_AMD_XDLOPS_HPP

-#include "float_type.hpp"
+#include "data_type.hpp"

 namespace ck {

@@ -7,8 +7,9 @@
 #include "statically_indexed_array.hpp"
 #include "container_element_picker.hpp"
 #include "multi_index.hpp"
+#include "data_type_enum.hpp"
 #include "data_type.hpp"
-#include "float_type.hpp"
+#include "data_type_helper.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
@@ -8,18 +8,13 @@
 #include "bfloat16_dev.hpp"

 // address space for kernel parameter
-#define __CONSTANT__ __attribute__((address_space(4)))
+#define CONSTANT __attribute__((address_space(4)))

 // device backend
 #define CK_DEVICE_BACKEND_AMD 1

-// GPU ID
-#if 0
-#define CK_AMD_GPU_GFX906 1
-#elif 1
-#define CK_AMD_GPU_GFX908 1
-#elif 0
-#define CK_AMD_GPU_GFX1030 1
+// GPU target
+// should enable one and only one GPU target
+#if !(defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
+      defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A) || defined(CK_AMD_GPU_GFX1030))
+#error Need to define a single GPU target
 #endif

 // HIP version
@@ -36,7 +31,8 @@
 #endif

 // buffer resourse
-#if defined(CK_AMD_GPU_GFX906) || defined(CK_AMD_GPU_GFX908)
+#if defined(CK_AMD_GPU_GFX803) || defined(CK_AMD_GPU_GFX900) || defined(CK_AMD_GPU_GFX906) || \
+    defined(CK_AMD_GPU_GFX908) || defined(CK_AMD_GPU_GFX90A)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x00020000
 #elif defined(CK_AMD_GPU_GFX1030)
 #define CK_BUFFER_RESOURCE_3RD_DWORD 0x31014000
@@ -50,10 +46,6 @@
 #define CK_USE_AMD_INLINE_ASM 1
 #endif

-#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
-#endif
-
 // AMD DLOPS
 #ifndef CK_USE_AMD_DLOP
 #define CK_USE_AMD_DLOP 1
@@ -78,14 +70,6 @@
 #define CK_USE_AMD_XDLOPS 0
 #endif

-#ifndef CK_USE_AMD_XDLOPS_INLINE_ASM
-#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
-#endif
-
-#ifndef CK_USE_AMD_XDLOPS_EMULATE
-#define CK_USE_AMD_XDLOPS_EMULATE 0 // For internal debug purposes
-#endif
-
 // block synchronization only s_wait lgkmcnt(0), not vmcnt(0)
 #ifndef CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
 #define CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM 1
@@ -104,18 +88,6 @@
 #define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1
 #endif

-#ifndef CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
-#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
-#endif
-
-#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK
-#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_OUTPUT_SKIP_OUT_OF_BOUND_CHECK 0
-#endif
-
-#ifndef CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK
-#define CK_EXPERIMENTAL_IMPLICIT_GEMM_BACKWARD_DATA_V4R1_INPUT_SKIP_OUT_OF_BOUND_CHECK 0
-#endif
-
 // pass tensor descriptor by value or void*
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0
 #define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1
@@ -131,17 +103,6 @@
 #define CK_HACK_DYNAMIC_MERGE_CALCULATE_IDX_DIFF_LOW_CONST_USE_AMD_GCN_READ_FIRST_LANE 0
 #endif

-// workaround: put all workaround here
-// workaround for unnecessary VGPR <--> AGPR data movement when using mfma LLVM intrinsic
-#ifndef CK_WORKAROUND_SWDEV_229564
-#define CK_WORKAROUND_SWDEV_229564 1
-#endif
-
-// workaround for accvgpr over-allocation
-#ifndef CK_WORKAROUND_SWDEV_241664
-#define CK_WORKAROUND_SWDEV_241664 1
-#endif
-
 // workaround for compiler crash when compiling recursive lambda
 #ifndef CK_WORKAROUND_SWDEV_275126
 #define CK_WORKAROUND_SWDEV_275126 1
@@ -159,7 +120,7 @@

 namespace ck {

-enum AddressSpace
+enum AddressSpaceEnum_t
 {
     Generic,
     Global,
@@ -168,7 +129,7 @@ enum AddressSpace
     Vgpr
 };

-enum InMemoryDataOperation
+enum InMemoryDataOperationEnum_t
 {
     Set,
     AtomicAdd
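The reworked target check above inverts the old scheme: instead of editing the header to pick a GPU, exactly one CK_AMD_GPU_GFX* macro must now arrive from the build (e.g. as a -D compiler flag), and the header only validates the selection and derives per-target constants such as CK_BUFFER_RESOURCE_3RD_DWORD, the third dword of the 128-bit buffer resource descriptor. A minimal sketch of consuming that selection (the function name is illustrative only):

// Illustrative only: how a translation unit consumes the target selection.
// Build with exactly one target defined, e.g. -DCK_AMD_GPU_GFX908.
#include <cstdint>

constexpr std::int32_t buffer_resource_3rd_dword()
{
#if defined(CK_AMD_GPU_GFX1030)
    return 0x31014000; // gfx10 encoding, per the hunk above
#else
    return 0x00020000; // gfx803/gfx900/gfx906/gfx908/gfx90a encoding
#endif
}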
(File diff suppressed because it is too large.)

composable_kernel/include/utility/data_type_enum.hpp (new file, 19 lines)
@@ -0,0 +1,19 @@
+#ifndef CK_DATA_TYPE_ENUM_HPP
+#define CK_DATA_TYPE_ENUM_HPP
+
+namespace ck {
+
+// this enumerate should be synchronized with include/miopen.h
+typedef enum {
+    Half     = 0,
+    Float    = 1,
+    Int32    = 2,
+    Int8     = 3,
+    Int8x4   = 4,
+    BFloat16 = 5,
+    Double   = 6,
+    Unknown  = 100,
+} DataTypeEnum_t;
+
+} // namespace ck
+#endif
composable_kernel/include/utility/data_type_helper.hpp (new file, 76 lines)
@@ -0,0 +1,76 @@
+#ifndef CK_DATA_TYPE_HELPER_HPP
+#define CK_DATA_TYPE_HELPER_HPP
+
+#include "data_type.hpp"
+#include "data_type_enum.hpp"
+
+namespace ck {
+
+template <DataTypeEnum_t DataTypeEnum>
+struct get_datatype_from_enum;
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum_t::Int8>
+{
+    using type = int8_t;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum_t::Int32>
+{
+    using type = int32_t;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum_t::Half>
+{
+    using type = half_t;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum_t::Float>
+{
+    using type = float;
+};
+
+template <>
+struct get_datatype_from_enum<DataTypeEnum_t::Double>
+{
+    using type = double;
+};
+
+template <typename T>
+struct get_datatype_enum_from_type;
+
+template <>
+struct get_datatype_enum_from_type<int8_t>
+{
+    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int8;
+};
+
+template <>
+struct get_datatype_enum_from_type<int32_t>
+{
+    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Int32;
+};
+
+template <>
+struct get_datatype_enum_from_type<half_t>
+{
+    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Half;
+};
+
+template <>
+struct get_datatype_enum_from_type<float>
+{
+    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Float;
+};
+
+template <>
+struct get_datatype_enum_from_type<double>
+{
+    static constexpr DataTypeEnum_t value = DataTypeEnum_t::Double;
+};
+
+} // namespace ck
+#endif
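The two new headers pair a MIOpen-compatible enum with trait-style mappings in both directions, so a template parameter carrying a DataTypeEnum_t can be turned back into a concrete C++ type at compile time. A hypothetical usage sketch (the bytes_per_element helper below is illustrative, not part of the commit):

// Hypothetical illustration, assuming data_type_helper.hpp is included.
#include "data_type_helper.hpp"

namespace ck {

template <DataTypeEnum_t DataTypeEnum>
constexpr int bytes_per_element()
{
    // recover the concrete C++ type from the enum at compile time
    using T = typename get_datatype_from_enum<DataTypeEnum>::type;
    return sizeof(T);
}

static_assert(bytes_per_element<DataTypeEnum_t::Float>() == 4, "float is 4 bytes");

} // namespace ck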
@@ -5,7 +5,7 @@ namespace ck {

#include "amd_buffer_addressing_v2.hpp"

-template <AddressSpace BufferAddressSpace, typename T, typename ElementSpaceSize>
+template <AddressSpaceEnum_t BufferAddressSpace, typename T, typename ElementSpaceSize>
struct DynamicBuffer
{
    using type = T;
@@ -18,7 +18,7 @@ struct DynamicBuffer
    {
    }

-    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
+    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
    {
        return BufferAddressSpace;
    }
@@ -32,7 +32,7 @@ struct DynamicBuffer
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
-    __host__ __device__ constexpr const auto Get(index_t i, bool is_valid_offset) const
+    __host__ __device__ constexpr auto Get(index_t i, bool is_valid_offset) const
    {
        // X contains multiple T
        constexpr index_t scalar_per_t_vector =
@@ -46,7 +46,7 @@ struct DynamicBuffer

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        if constexpr(GetAddressSpace() == AddressSpace::Global)
+        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            return amd_buffer_load_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
@@ -80,7 +80,7 @@ struct DynamicBuffer

        constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;

-        if constexpr(GetAddressSpace() == AddressSpace::Global)
+        if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Global)
        {
#if CK_USE_AMD_BUFFER_ADDRESSING
            amd_buffer_store_v2<remove_cv_t<remove_reference_t<T>>, t_per_x>(
@@ -92,14 +92,15 @@ struct DynamicBuffer
            }
#endif
        }
-        else if constexpr(GetAddressSpace() == AddressSpace::Lds)
+        else if constexpr(GetAddressSpace() == AddressSpaceEnum_t::Lds)
        {
            if(is_valid_offset)
            {
#if !CK_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE
                *reinterpret_cast<X*>(&p_data_[i]) = x;
#else
-                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
+                // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into
+                // inefficient
                // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lower to
                // ds_write_b128
                // TODO: remove this after compiler fix
@@ -119,7 +120,8 @@ struct DynamicBuffer
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
                                (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                                 is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
-                "wrong! not implemented for this combination, please add implementation");
+                "wrong! not implemented for this combination, please add "
+                "implementation");

            if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
@@ -194,7 +196,7 @@ struct DynamicBuffer
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

-template <AddressSpace BufferAddressSpace = AddressSpace::Generic,
+template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
          typename T,
          typename ElementSpaceSize>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize element_space_size)
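The SWDEV workaround referenced in the HACK comment widens narrow int8 stores so the backend can emit wide LDS writes instead of per-byte ones. A standalone sketch of that bit-cast idea, with illustrative names rather than the library's actual helper:

typedef int8_t int8x4_sketch_t __attribute__((ext_vector_type(4)));

// Widen a 4-byte int8 vector store into a single 32-bit store so the backend
// can emit ds_write_b32/b128 instead of per-byte writes. Sketch only.
__device__ void store_int8x4_as_i32(int8_t* p_lds, int i, int8x4_sketch_t x)
{
    int32_t tmp;
    __builtin_memcpy(&tmp, &x, sizeof(tmp));      // bit-cast, no value conversion
    *reinterpret_cast<int32_t*>(&p_lds[i]) = tmp; // one dword store
}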
@@ -1,999 +0,0 @@
#ifndef CK_FLOAT_TYPE_AMD_HPP
#define CK_FLOAT_TYPE_AMD_HPP

#include "statically_indexed_array.hpp"

namespace ck {

using half_t = _Float16;

// vector_type
template <typename T, index_t N>
struct vector_type;

// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
template <typename T, index_t V, index_t N>
struct vector_type<T __attribute__((ext_vector_type(V))), N>;

// Caution: DO NOT REMOVE
// intentionally have only declaration but no definition to cause compilation failure when trying to
// instantiate this template. The purpose is to catch user's mistake when trying to make "vector of
// vectors"
template <typename T, index_t V, index_t N>
struct vector_type<vector_type<T, V>, N>;

// vector_type_maker
// This is the right way to handle "vector of vectors": making a bigger vector instead
template <typename T, index_t N>
struct vector_type_maker
{
    using type = vector_type<T, N>;
};

template <typename T, index_t N0, index_t N1>
struct vector_type_maker<T __attribute__((ext_vector_type(N1))), N0>
{
    using type = vector_type<T, N0 * N1>;
};

template <typename T, index_t N0, index_t N1>
struct vector_type_maker<vector_type<T, N1>, N0>
{
    using type = vector_type<T, N0 * N1>;
};

template <typename T, index_t N>
using vector_type_maker_t = typename vector_type_maker<T, N>::type;

template <typename T, index_t N>
__host__ __device__ constexpr auto make_vector_type(Number<N>)
{
    return typename vector_type_maker<T, N>::type{};
}
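The maker specializations above flatten any nested request into one wide vector instead of a "vector of vectors". A compile-time sketch of that behavior, assuming this header and ck::is_same are in scope:

namespace ck {
static_assert(is_same<vector_type_maker_t<float, 4>, vector_type<float, 4>>::value, "");
static_assert(is_same<vector_type_maker_t<vector_type<float, 4>, 2>,
                      vector_type<float, 8>>::value,
              "a nested request must flatten to one 8-wide vector");
} // namespace ck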
// scalar_type
template <typename TV>
struct scalar_type;

template <typename T, index_t N>
struct scalar_type<T __attribute__((ext_vector_type(N)))>
{
    using type = T;
    static constexpr index_t vector_size = N;
};

template <typename T, index_t N>
struct scalar_type<vector_type<T, N>>
{
    using type = T;
    static constexpr index_t vector_size = N;
};

//
template <>
struct scalar_type<float>
{
    using type = float;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<half_t>
{
    using type = half_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<ushort>
{
    using type = ushort;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<int32_t>
{
    using type = int32_t;
    static constexpr index_t vector_size = 1;
};

template <>
struct scalar_type<int8_t>
{
    using type = int8_t;
    static constexpr index_t vector_size = 1;
};
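scalar_type is what lets generic code such as DynamicBuffer::Get ask "how many scalars per vector?" for scalars and ext_vector types alike. Quick compile-time checks, as a sketch assuming this header is included:

namespace ck {
static_assert(scalar_type<vector_type<half_t, 8>>::vector_size == 8, "");
static_assert(scalar_type<float>::vector_size == 1, "");
static_assert(is_same<scalar_type<vector_type<int8_t, 4>>::type, int8_t>::value, "");
} // namespace ck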
//
template <typename T>
struct vector_type<T, 1>
{
    using d1_t = T;
    using type = d1_t;

    union
    {
        T d1_;
        StaticallyIndexedArray<T, 1> d1x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value, "wrong!");

        return data_.d1x1_;
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value, "wrong!");

        return data_.d1x1_;
    }
};
template <typename T>
struct vector_type<T, 2>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));

    using type = d2_t;

    union
    {
        d2_t d2_;
        StaticallyIndexedArray<d1_t, 2> d1x2_;
        StaticallyIndexedArray<d2_t, 1> d2x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x2_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value, "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x2_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x1_;
        }
    }
};
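The unions above let one register blob be viewed at several granularities, and AsType<X>() picks the view. A small usage sketch, assuming this header is included:

__device__ ck::half_t sum_lanes(ck::vector_type<ck::half_t, 2> v)
{
    using namespace ck;
    // Read the same storage as two half_t lanes and add them.
    return v.AsType<half_t>()[Number<0>{}] + v.AsType<half_t>()[Number<1>{}];
}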
template <typename T>
struct vector_type<T, 4>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));

    using type = d4_t;

    union
    {
        d4_t d4_;
        StaticallyIndexedArray<d1_t, 4> d1x4_;
        StaticallyIndexedArray<d2_t, 2> d2x2_;
        StaticallyIndexedArray<d4_t, 1> d4x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x4_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x2_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x4_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x2_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x1_;
        }
    }
};
template <typename T>
struct vector_type<T, 8>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));

    using type = d8_t;

    union
    {
        d8_t d8_;
        StaticallyIndexedArray<d1_t, 8> d1x8_;
        StaticallyIndexedArray<d2_t, 4> d2x4_;
        StaticallyIndexedArray<d4_t, 2> d4x2_;
        StaticallyIndexedArray<d8_t, 1> d8x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x8_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x4_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x2_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x8_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x4_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x2_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x1_;
        }
    }
};
template <typename T>
struct vector_type<T, 16>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));
    typedef T d16_t __attribute__((ext_vector_type(16)));

    using type = d16_t;

    union
    {
        d16_t d16_;
        StaticallyIndexedArray<d1_t, 16> d1x16_;
        StaticallyIndexedArray<d2_t, 8> d2x8_;
        StaticallyIndexedArray<d4_t, 4> d4x4_;
        StaticallyIndexedArray<d8_t, 2> d8x2_;
        StaticallyIndexedArray<d16_t, 1> d16x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x16_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x8_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x4_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x2_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x16_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x8_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x4_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x2_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x1_;
        }
    }
};
template <typename T>
struct vector_type<T, 32>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));
    typedef T d16_t __attribute__((ext_vector_type(16)));
    typedef T d32_t __attribute__((ext_vector_type(32)));

    using type = d32_t;

    union
    {
        d32_t d32_;
        StaticallyIndexedArray<d1_t, 32> d1x32_;
        StaticallyIndexedArray<d2_t, 16> d2x16_;
        StaticallyIndexedArray<d4_t, 8> d4x8_;
        StaticallyIndexedArray<d8_t, 4> d8x4_;
        StaticallyIndexedArray<d16_t, 2> d16x2_;
        StaticallyIndexedArray<d32_t, 1> d32x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value || is_same<X, d32_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x32_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x16_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x8_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x4_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x2_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value || is_same<X, d32_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x32_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x16_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x8_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x4_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x2_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x1_;
        }
    }
};
template <typename T>
struct vector_type<T, 64>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));
    typedef T d16_t __attribute__((ext_vector_type(16)));
    typedef T d32_t __attribute__((ext_vector_type(32)));
    typedef T d64_t __attribute__((ext_vector_type(64)));

    using type = d64_t;

    union
    {
        d64_t d64_;
        StaticallyIndexedArray<d1_t, 64> d1x64_;
        StaticallyIndexedArray<d2_t, 32> d2x32_;
        StaticallyIndexedArray<d4_t, 16> d4x16_;
        StaticallyIndexedArray<d8_t, 8> d8x8_;
        StaticallyIndexedArray<d16_t, 4> d16x4_;
        StaticallyIndexedArray<d32_t, 2> d32x2_;
        StaticallyIndexedArray<d64_t, 1> d64x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
                          is_same<X, d64_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x64_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x32_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x16_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x8_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x4_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x2_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
                          is_same<X, d64_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x64_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x32_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x16_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x8_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x4_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x2_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x1_;
        }
    }
};
template <typename T>
struct vector_type<T, 128>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));
    typedef T d16_t __attribute__((ext_vector_type(16)));
    typedef T d32_t __attribute__((ext_vector_type(32)));
    typedef T d64_t __attribute__((ext_vector_type(64)));
    typedef T d128_t __attribute__((ext_vector_type(128)));

    using type = d128_t;

    union
    {
        d128_t d128_;
        StaticallyIndexedArray<d1_t, 128> d1x128_;
        StaticallyIndexedArray<d2_t, 64> d2x64_;
        StaticallyIndexedArray<d4_t, 32> d4x32_;
        StaticallyIndexedArray<d8_t, 16> d8x16_;
        StaticallyIndexedArray<d16_t, 8> d16x8_;
        StaticallyIndexedArray<d32_t, 4> d32x4_;
        StaticallyIndexedArray<d64_t, 2> d64x2_;
        StaticallyIndexedArray<d128_t, 1> d128x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
                          is_same<X, d64_t>::value || is_same<X, d128_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x128_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x64_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x32_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x16_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x8_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x4_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x2_;
        }
        else if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
                          is_same<X, d64_t>::value || is_same<X, d128_t>::value,
                      "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x128_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x64_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x32_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x16_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x8_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x4_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x2_;
        }
        else if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x1_;
        }
    }
};
template <typename T>
struct vector_type<T, 256>
{
    using d1_t = T;
    typedef T d2_t __attribute__((ext_vector_type(2)));
    typedef T d4_t __attribute__((ext_vector_type(4)));
    typedef T d8_t __attribute__((ext_vector_type(8)));
    typedef T d16_t __attribute__((ext_vector_type(16)));
    typedef T d32_t __attribute__((ext_vector_type(32)));
    typedef T d64_t __attribute__((ext_vector_type(64)));
    typedef T d128_t __attribute__((ext_vector_type(128)));
    typedef T d256_t __attribute__((ext_vector_type(256)));

    using type = d256_t;

    union
    {
        d256_t d256_;
        StaticallyIndexedArray<d1_t, 256> d1x256_;
        StaticallyIndexedArray<d2_t, 128> d2x128_;
        StaticallyIndexedArray<d4_t, 64> d4x64_;
        StaticallyIndexedArray<d8_t, 32> d8x32_;
        StaticallyIndexedArray<d16_t, 16> d16x16_;
        StaticallyIndexedArray<d32_t, 8> d32x8_;
        StaticallyIndexedArray<d64_t, 4> d64x4_;
        StaticallyIndexedArray<d128_t, 2> d128x2_;
        StaticallyIndexedArray<d256_t, 1> d256x1_;
    } data_;

    __host__ __device__ constexpr vector_type() : data_{type{0}} {}

    __host__ __device__ constexpr vector_type(type v) : data_{v} {}

    template <typename X>
    __host__ __device__ constexpr const auto& AsType() const
    {
        static_assert(
            is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
                is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
                is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
            "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x256_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x128_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x64_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x32_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x16_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x8_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x4_;
        }
        else if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x2_;
        }
        else if constexpr(is_same<X, d256_t>::value)
        {
            return data_.d256x1_;
        }
    }

    template <typename X>
    __host__ __device__ constexpr auto& AsType()
    {
        static_assert(
            is_same<X, d1_t>::value || is_same<X, d2_t>::value || is_same<X, d4_t>::value ||
                is_same<X, d8_t>::value || is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
                is_same<X, d64_t>::value || is_same<X, d128_t>::value || is_same<X, d256_t>::value,
            "wrong!");

        if constexpr(is_same<X, d1_t>::value)
        {
            return data_.d1x256_;
        }
        else if constexpr(is_same<X, d2_t>::value)
        {
            return data_.d2x128_;
        }
        else if constexpr(is_same<X, d4_t>::value)
        {
            return data_.d4x64_;
        }
        else if constexpr(is_same<X, d8_t>::value)
        {
            return data_.d8x32_;
        }
        else if constexpr(is_same<X, d16_t>::value)
        {
            return data_.d16x16_;
        }
        else if constexpr(is_same<X, d32_t>::value)
        {
            return data_.d32x8_;
        }
        else if constexpr(is_same<X, d64_t>::value)
        {
            return data_.d64x4_;
        }
        else if constexpr(is_same<X, d128_t>::value)
        {
            return data_.d128x2_;
        }
        else if constexpr(is_same<X, d256_t>::value)
        {
            return data_.d256x1_;
        }
    }
};
// fp32
using float2_t = typename vector_type<float, 2>::type;
using float4_t = typename vector_type<float, 4>::type;
using float8_t = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;

// fp16
using half2_t = typename vector_type<half_t, 2>::type;
using half4_t = typename vector_type<half_t, 4>::type;
using half8_t = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;

// bfp16
using ushort2_t = typename vector_type<ushort, 2>::type;
using ushort4_t = typename vector_type<ushort, 4>::type;
using ushort8_t = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;

// i32
using int32x2_t = typename vector_type<int32_t, 2>::type;
using int32x4_t = typename vector_type<int32_t, 4>::type;
using int32x8_t = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;

// i8
using int8x2_t = typename vector_type<int8_t, 2>::type;
using int8x4_t = typename vector_type<int8_t, 4>::type;
using int8x8_t = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// data type conversion
template <typename T>
struct type_convert
{
    template <typename X>
    __device__ T operator()(X x) const
    {
        return static_cast<T>(x);
    }
};

template <>
template <>
__device__ float type_convert<float>::operator()<ushort>(ushort x) const
{
    return bfloat16_to_float(x);
}

template <>
template <>
__device__ ushort type_convert<ushort>::operator()<float>(float x) const
{
    return float_to_bfloat16(x);
}
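type_convert is a functor, so conversions can be handed to generic code and specialized per type pair; the two specializations above reroute the bf16 paths through the bfloat16 helper functions. A usage sketch, assuming this header is included:

__device__ float widen_bf16(ushort x_bf16)
{
    return ck::type_convert<float>{}(x_bf16); // dispatches to bfloat16_to_float
}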
template <typename T>
struct inner_product_with_conversion
{
    static constexpr auto convert = type_convert<T>();

    template <typename X, index_t N>
    __device__ T operator()(typename vector_type<X, N>::type a,
                            typename vector_type<X, N>::type b) const
    {
        const vector_type<X, N> a_vector{a};
        const vector_type<X, N> b_vector{b};

        T acc = 0;

        static_for<0, N, 1>{}([&](auto i) {
            acc += convert(a_vector.Scalars()[i]) * convert(b_vector.Scalars()[i]);
        });

        return acc;
    }

    __device__ T operator()(float_t a, float_t b) const { return convert(a) * convert(b); }

    __device__ T operator()(int8x4_t a, int8x4_t b) const
    {
        const vector_type<int8_t, 4> a_vector{a};
        const vector_type<int8_t, 4> b_vector{b};

        T acc = 0;

        static_for<0, 4, 1>{}([&](auto i) {
            acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
        });

        return acc;
    }

    __device__ T operator()(int8x8_t a, int8x8_t b) const
    {
        const vector_type<int8_t, 8> a_vector{a};
        const vector_type<int8_t, 8> b_vector{b};

        T acc = 0;

        static_for<0, 8, 1>{}([&](auto i) {
            acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
        });

        return acc;
    }

    __device__ T operator()(int8x16_t a, int8x16_t b) const
    {
        const vector_type<int8_t, 16> a_vector{a};
        const vector_type<int8_t, 16> b_vector{b};

        T acc = 0;

        static_for<0, 16, 1>{}([&](auto i) {
            acc += convert(a_vector.AsType<int8_t>()[i]) * convert(b_vector.AsType<int8_t>()[i]);
        });

        return acc;
    }
};

} // namespace ck
#endif
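inner_product_with_conversion converts each lane before multiplying, which is how int8 data accumulates into a wider type without overflow. A usage sketch, assuming the header above:

__device__ int32_t dot_int8x4(ck::int8x4_t a, ck::int8x4_t b)
{
    // Four int8 products accumulated in int32.
    return ck::inner_product_with_conversion<int32_t>{}(a, b);
}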
@@ -127,7 +127,8 @@ struct MagicDivision
    DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
    {
        uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
-        uint32_t tmp = ((uint64_t)dividend_u32 * (uint64_t)multiplier) >> 32;
+        uint32_t tmp =
+            (static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
        return (tmp + dividend_u32) >> shift;
    }
#else
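DoMagicDivision computes q = ((n * M) >> 32 + n) >> s, the "add-back" magic-number form of integer division. One standard way to construct such an (M, s) pair is s = ceil(log2(d)) and M = ceil(2^(32+s) / d) - 2^32; this is a host-side sanity harness for that recipe, not necessarily the library's own precomputation:

#include <cstdint>
#include <cstdio>

// Same add-back formula as DoMagicDivision above, on the host.
static uint32_t magic_div(uint32_t n, uint32_t multiplier, uint32_t shift)
{
    uint32_t tmp = (static_cast<uint64_t>(n) * multiplier) >> 32;
    return (tmp + n) >> shift;
}

int main()
{
    const uint32_t d = 7; // any divisor in [1, 2^31] works with this recipe
    uint32_t shift = 0;
    while((uint64_t{1} << shift) < d) ++shift; // shift = ceil(log2(d))
    // multiplier = ceil(2^(32+shift) / d) - 2^32, which always fits in 32 bits
    uint64_t m = ((uint64_t{1} << (32 + shift)) + d - 1) / d;
    uint32_t multiplier = static_cast<uint32_t>(m - (uint64_t{1} << 32));

    for(uint32_t n = 0; n < 10000000; ++n)
    {
        if(magic_div(n, multiplier, shift) != n / d)
        {
            std::printf("mismatch at n = %u\n", n);
            return 1;
        }
    }
    return 0;
}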
@@ -150,7 +150,15 @@ __host__ __device__ constexpr auto min(X x, Ys... ys)
// greatest common divisor, aka highest common factor
__host__ __device__ constexpr index_t gcd(index_t x, index_t y)
{
-    if(x == y || x == 0)
+    if(x < 0)
+    {
+        return gcd(-x, y);
+    }
+    else if(y < 0)
+    {
+        return gcd(x, -y);
+    }
+    else if(x == y || x == 0)
    {
        return y;
    }
@@ -160,11 +168,11 @@ __host__ __device__ constexpr index_t gcd(index_t x, index_t y)
    }
    else if(x > y)
    {
-        return gcd(x - y, y);
+        return gcd(x % y, y);
    }
    else
    {
-        return gcd(x, y - x);
+        return gcd(x, y % x);
    }
}

@@ -181,7 +189,7 @@ template <typename X,
          typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
{
-    return gcd(x, ys...);
+    return gcd(x, gcd(ys...));
}

// least common multiple
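The switch from repeated subtraction to remainders makes gcd terminate in O(log min(x, y)) steps instead of O(x / y), negative operands are now normalized up front, and the variadic overload actually folds instead of calling itself with the same arguments. Compile-time checks against the patched ck::gcd, as a sketch:

static_assert(ck::gcd(24, 36) == 12, "");
static_assert(ck::gcd(-24, 36) == 12, "negative inputs are normalized first");
static_assert(ck::gcd(8, 12, 20) == 4, "variadic gcd folds pairwise");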
@@ -5,7 +5,7 @@

namespace ck {

-template <AddressSpace BufferAddressSpace, typename T, index_t N>
+template <AddressSpaceEnum_t BufferAddressSpace, typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
@@ -13,7 +13,7 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>

    __host__ __device__ constexpr StaticBuffer() : base{} {}

-    __host__ __device__ static constexpr AddressSpace GetAddressSpace()
+    __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace()
    {
        return BufferAddressSpace;
    }
@@ -23,7 +23,9 @@ struct StaticBuffer : public StaticallyIndexedArray<T, N>
    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

-template <AddressSpace BufferAddressSpace = AddressSpace::Generic, typename T, index_t N>
+template <AddressSpaceEnum_t BufferAddressSpace = AddressSpaceEnum_t::Generic,
+          typename T,
+          index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<BufferAddressSpace, T, N>{};
@@ -5,8 +5,6 @@

namespace ck {

-__device__ void __llvm_amdgcn_s_barrier() __asm("llvm.amdgcn.s.barrier");
-
__device__ void block_sync_lds()
{
#if CK_BLOCK_SYNC_LDS_WITHOUT_SYNC_VMEM
@@ -15,11 +13,9 @@ __device__ void block_sync_lds()
                 s_barrier \
                 " ::);
#else
-    __llvm_amdgcn_s_barrier();
+    __syncthreads();
#endif
}

-__device__ void block_sync_lds_vmem() { __llvm_amdgcn_s_barrier(); }
-
} // namespace ck
#endif
@@ -1,34 +0,0 @@
#ifndef CK_TYPE_HELPER_HPP
#define CK_TYPE_HELPER_HPP

#include "float_type.hpp"

namespace ck {

template <char tid>
struct get_type_from_type_id
{
    using type = float;
};

template <>
struct get_type_from_type_id<'H'>
{
    using type = half_t;
};

template <>
struct get_type_from_type_id<'F'>
{
    using type = float;
};

template <>
struct get_type_from_type_id<'D'>
{
    using type = double;
};

} // namespace ck

#endif
@@ -1,15 +1,18 @@
#include "common_header.hpp"
-#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
-#include "gridwise_dynamic_gemm_v1r2.hpp"
+#include "gridwise_dynamic_gemm_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v4r4_nchw_kcyx_nkhw.hpp"

using namespace ck;

-using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
-using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
-using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
+constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
+constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
+constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
+
+using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
+using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
+using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;
@@ -61,7 +64,8 @@ constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDs
constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);

-extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw_prepare(
+extern "C" __global__ void
+dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw_prepare(
    int n,
    int c,
    int hi,
@@ -147,48 +151,48 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v4r4_nchw_k
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
-        GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
+        GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
                                          FloatAB,
                                          FloatAcc,
                                          FloatC,
-                                         InMemoryDataOperation::Set, /* ToDo tunable */
+                                         InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
                                          AKMGridDesc,
                                          BKNGridDesc,
                                          CMNGridDesc,
                                          MPerBlock,
                                          NPerBlock,
                                          KPerBlock,
                                          M1PerThread,
                                          N1PerThread,
                                          KPerThread,
                                          M1N1ThreadClusterM10,
                                          M1N1ThreadClusterN10,
                                          M1N1ThreadClusterM11,
                                          M1N1ThreadClusterN11,
                                          ABlockTransferThreadSliceLengths_K_M0_M1,
                                          ABlockTransferThreadClusterLengths_K_M0_M1,
                                          ABlockTransferThreadClusterArrangeOrder,
                                          ABlockTransferSrcAccessOrder,
                                          ABlockTransferSrcVectorDim,
                                          ABlockTransferSrcScalarPerVector,
                                          ABlockTransferDstScalarPerVector_M1,
                                          AThreadTransferSrcResetCoordinateAfterRun,
                                          BBlockTransferThreadSliceLengths_K_N0_N1,
                                          BBlockTransferThreadClusterLengths_K_N0_N1,
                                          BBlockTransferThreadClusterArrangeOrder,
                                          BBlockTransferSrcAccessOrder,
                                          BBlockTransferSrcVectorDim,
                                          BBlockTransferSrcScalarPerVector,
                                          BBlockTransferDstScalarPerVector_N1,
                                          BThreadTransferSrcResetCoordinateAfterRun,
                                          CThreadTransferSrcDstAccessOrder,
                                          CThreadTransferSrcDstVectorDim,
                                          CThreadTransferDstScalarPerVector,
                                          AGridIteratorHacks,
                                          BGridIteratorHacks,
                                          CGridIteratorHacks,
                                          AGridMoveSliceWindowIteratorHacks,
                                          BGridMoveSliceWindowIteratorHacks>;

    auto a_k_m0_m1_grid_desc = GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
    auto b_k_n0_n1_grid_desc = GridwiseGemm::MakeBKN0N1GridDescriptor(b_k_n_grid_desc);
@@ -212,14 +216,14 @@ extern "C" __global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
-    dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(
+    dynamic_convolution_forward_implicit_gemm_v4r4_dlops_nchw_kcyx_nkhw(
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-        const void __CONSTANT__* p_a_k_m0_m1_grid_desc,
-        const void __CONSTANT__* p_b_k_n0_n1_grid_desc,
-        const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
-        const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
+        const void CONSTANT* p_a_k_m0_m1_grid_desc,
+        const void CONSTANT* p_b_k_n0_n1_grid_desc,
+        const void CONSTANT* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
+        const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
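These descriptor arguments live in constant memory, where HIP places kernel arguments; the renamed CONSTANT macro is assumed here to expand to the address-space-4 attribute, and the kernels strip the qualifier with a plain (const void*) cast before typing the pointer, which is what silences the cast warning the commit mentions. A minimal sketch of the pattern with a stand-in macro and descriptor type:

#define CONSTANT_SKETCH __attribute__((address_space(4))) // assumption

template <typename Desc>
__device__ Desc load_desc(const void CONSTANT_SKETCH* p)
{
    // Two-step cast: address-space-4 pointer -> generic const void* -> typed pointer.
    return *reinterpret_cast<const Desc*>((const void*)p);
}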
@@ -283,48 +287,48 @@ extern "C" __global__ void
    using BGridMoveSliceWindowIteratorHacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>;

    using GridwiseGemm =
-        GridwiseDynamicGemm_km_kn_mn_v1r2<BlockSize,
+        GridwiseDynamicGemmDlops_km_kn_mn_v1r2<BlockSize,
                                          FloatAB,
                                          FloatAcc,
                                          FloatC,
-                                         InMemoryDataOperation::Set, /* ToDo tunable */
+                                         InMemoryDataOperationEnum_t::Set, /* ToDo tunable */
                                          AKMGridDesc,
                                          BKNGridDesc,
                                          CMNGridDesc,
                                          MPerBlock,
                                          NPerBlock,
                                          KPerBlock,
                                          M1PerThread,
                                          N1PerThread,
                                          KPerThread,
                                          M1N1ThreadClusterM10,
                                          M1N1ThreadClusterN10,
                                          M1N1ThreadClusterM11,
                                          M1N1ThreadClusterN11,
                                          ABlockTransferThreadSliceLengths_K_M0_M1,
                                          ABlockTransferThreadClusterLengths_K_M0_M1,
                                          ABlockTransferThreadClusterArrangeOrder,
                                          ABlockTransferSrcAccessOrder,
                                          ABlockTransferSrcVectorDim,
                                          ABlockTransferSrcScalarPerVector,
                                          ABlockTransferDstScalarPerVector_M1,
                                          AThreadTransferSrcResetCoordinateAfterRun,
                                          BBlockTransferThreadSliceLengths_K_N0_N1,
                                          BBlockTransferThreadClusterLengths_K_N0_N1,
                                          BBlockTransferThreadClusterArrangeOrder,
                                          BBlockTransferSrcAccessOrder,
                                          BBlockTransferSrcVectorDim,
                                          BBlockTransferSrcScalarPerVector,
                                          BBlockTransferDstScalarPerVector_N1,
                                          BThreadTransferSrcResetCoordinateAfterRun,
                                          CThreadTransferSrcDstAccessOrder,
                                          CThreadTransferSrcDstVectorDim,
                                          CThreadTransferDstScalarPerVector,
                                          AGridIteratorHacks,
                                          BGridIteratorHacks,
                                          CGridIteratorHacks,
                                          AGridMoveSliceWindowIteratorHacks,
                                          BGridMoveSliceWindowIteratorHacks>;

    constexpr auto a_k_m0_m1_grid_desc_tmp =
        GridwiseGemm::MakeAKM0M1GridDescriptor(a_k_m_grid_desc);
@@ -1,5 +1,4 @@
#include "common_header.hpp"
-#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
@@ -7,9 +6,13 @@

using namespace ck;

-using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
-using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
-using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
+constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
+constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
+constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
+
+using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
+using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
+using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;

@@ -149,7 +152,7 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw_prepare(
        FloatAB,
        FloatAcc,
        FloatC,
-        InMemoryDataOperation::Set,
+        InMemoryDataOperationEnum_t::Set,
        AK0MK1GridDesc,
        BK0NK1GridDesc,
        CMNGridDesc,
@@ -213,10 +216,10 @@ extern "C" __global__ void
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-        const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
-        const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
-        const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
-        const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
+        const void CONSTANT* p_a_k0_m_k1_grid_desc,
+        const void CONSTANT* p_b_k0_n_k1_grid_desc,
+        const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
+        const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{

    constexpr auto I0 = Number<0>{};
@@ -286,7 +289,7 @@ extern "C" __global__ void
        FloatAB,
        FloatAcc,
        FloatC,
-        InMemoryDataOperation::Set,
+        InMemoryDataOperationEnum_t::Set,
        AK0MK1GridDesc,
        BK0NK1GridDesc,
        CMNGridDesc,
@@ -1,5 +1,4 @@
#include "common_header.hpp"
-#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_xdlops_v2r3.hpp"
@@ -7,9 +6,13 @@

using namespace ck;

-using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
-using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
-using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_CONV_COMPTYPE)>::type;
+constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
+constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
+constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
+
+using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
+using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
+using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;

@@ -149,7 +152,7 @@ dynamic_convolution_forward_implicit_gemm_v4r4_xdlops_nhwc_kyxc_nhwk_prepare(
        FloatAB,
        FloatAcc,
        FloatC,
-        InMemoryDataOperation::Set,
+        InMemoryDataOperationEnum_t::Set,
        AK0MK1GridDesc,
        BK0NK1GridDesc,
        CMNGridDesc,
@@ -213,10 +216,10 @@ extern "C" __global__ void
        const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
-        const void __CONSTANT__* p_a_k0_m_k1_grid_desc,
-        const void __CONSTANT__* p_b_k0_n_k1_grid_desc,
-        const void __CONSTANT__* p_c_m0_m1_m2_n_grid_desc,
-        const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
+        const void CONSTANT* p_a_k0_m_k1_grid_desc,
+        const void CONSTANT* p_b_k0_n_k1_grid_desc,
+        const void CONSTANT* p_c_m0_m1_m2_n_grid_desc,
+        const void CONSTANT* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{

    constexpr auto I0 = Number<0>{};
@@ -287,7 +290,7 @@ extern "C" __global__ void
        FloatAB,
        FloatAcc,
        FloatC,
-        InMemoryDataOperation::Set,
+        InMemoryDataOperationEnum_t::Set,
        AK0MK1GridDesc,
        BK0NK1GridDesc,
        CMNGridDesc,
@@ -1,31 +1,34 @@
#include "common_header.hpp"
-#include "type_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
-#include "gridwise_dynamic_contraction_v1r2.hpp"
+#include "gridwise_dynamic_contraction_dlops_v1r2.hpp"
#include "transform_forward_convolution_into_gemm_v6r1_nchw_kcyx_nkhw.hpp"

using namespace ck;

-using FloatAB = typename get_type_from_type_id<static_cast<char>(CK_PARAM_IN_WEI_DATATYPE)>::type;
-using FloatAcc = typename get_type_from_type_id<static_cast<char>(CK_PARAM_ACC_DATATYPE)>::type;
-using FloatC = typename get_type_from_type_id<static_cast<char>(CK_PARAM_OUT_DATATYPE)>::type;
+constexpr DataTypeEnum_t ABDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_ABDataTypeEnum);
+constexpr DataTypeEnum_t AccDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_AccDataTypeEnum);
+constexpr DataTypeEnum_t CDataTypeEnum = static_cast<DataTypeEnum_t>(CK_PARAM_CDataTypeEnum);
+
+using FloatAB = typename get_datatype_from_enum<ABDataTypeEnum>::type;
+using FloatAcc = typename get_datatype_from_enum<AccDataTypeEnum>::type;
+using FloatC = typename get_datatype_from_enum<CDataTypeEnum>::type;

constexpr index_t BlockSize = CK_PARAM_BlockSize;

constexpr auto GN0 = Number<CK_PARAM_GN0>{};
constexpr auto GK1 = Number<CK_PARAM_GK1>{};

constexpr index_t GM1PerBlockGM11 = CK_PARAM_GM1PerBlockGM11;
constexpr index_t GN1PerBlockGN11 = CK_PARAM_GN1PerBlockGN11;
constexpr index_t GK0PerBlock = CK_PARAM_GK0PerBlock;
+
constexpr index_t BM1PerThreadBM11 = CK_PARAM_BM1PerThreadBM11;
constexpr index_t BN1PerThreadBN11 = CK_PARAM_BN1PerThreadBN11;
constexpr index_t BK0PerThread = CK_PARAM_BK0PerThread;
-constexpr index_t BM10BN10ThreadClusterBM100 = CK_PARAM_BM10BN10ThreadClusterBM100;
-constexpr index_t BM10BN10ThreadClusterBN100 = CK_PARAM_BM10BN10ThreadClusterBN100;
-constexpr index_t BM10BN10ThreadClusterBM101 = CK_PARAM_BM10BN10ThreadClusterBM101;
-constexpr index_t BM10BN10ThreadClusterBN101 = CK_PARAM_BM10BN10ThreadClusterBN101;
+
+using BM10BN10ThreadClusterBM10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBM10Xs>;
+using BM10BN10ThreadClusterBN10Xs = Sequence<CK_PARAM_BM10BN10ThreadClusterBN10Xs>;

using ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 =
    Sequence<CK_PARAM_ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1>;
@@ -55,29 +58,26 @@ using CThreadTransferSrcDstAccessOrder = Sequence<3, 4, 5, 0, 1, 2>
constexpr index_t CThreadTransferSrcDstVectorDim = 5;
constexpr index_t CThreadTransferDstScalarPerVector = CK_PARAM_CThreadTransferDstScalarPerVector;

-constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HAS_MAIN_KBLOCK_LOOP);
-constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HAS_DOUBLE_TAIL_KBLOCK_LOOP);
+constexpr bool HasMainKBlockLoop = static_cast<bool>(CK_PARAM_HasMainKBlockLoop);
+constexpr bool HasDoubleTailKBlockLoop = static_cast<bool>(CK_PARAM_HasDoubleTailKBlockLoop);

-extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw_prepare(
-    index_t N,
-    index_t C,
-    index_t Hi,
-    index_t Wi,
-    index_t K,
-    index_t Y,
-    index_t X,
-    index_t ConvStrideH,
-    index_t ConvStrideW,
-    index_t ConvDilationH,
-    index_t ConvDilationW,
-    index_t InLeftPadH,
-    index_t InLeftPadW,
-    index_t InRightPadH,
-    index_t InRightPadW,
-    void* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1,
-    void* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
-    void* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
-    void* p_c_grid_block_cluster_blockid_to_gm10_gn10)
+extern "C" __global__ void
+dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare(index_t N,
+                                                                            index_t C,
+                                                                            index_t Hi,
+                                                                            index_t Wi,
+                                                                            index_t K,
+                                                                            index_t Y,
+                                                                            index_t X,
+                                                                            index_t ConvStrideH,
+                                                                            index_t ConvStrideW,
+                                                                            index_t ConvDilationH,
+                                                                            index_t ConvDilationW,
+                                                                            index_t InLeftPadH,
+                                                                            index_t InLeftPadW,
+                                                                            index_t InRightPadH,
+                                                                            index_t InRightPadW,
+                                                                            void* p_desc_tuple)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
@@ -160,12 +160,12 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;

    using GridwiseContraction =
-        GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
+        GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
            BlockSize,
            FloatAB,
            FloatAcc,
            FloatC,
-            InMemoryDataOperation::Set,
+            InMemoryDataOperationEnum_t::Set,
            AGridDesc_GK0_GM0_GM1_GK1,
            BGridDesc_GK0_GN0_GN1_GK1,
            CGridDesc_GM0_GM1_GN0_GN1,
@@ -175,10 +175,8 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
            BM1PerThreadBM11,
            BN1PerThreadBN11,
            BK0PerThread,
-            BM10BN10ThreadClusterBM100,
-            BM10BN10ThreadClusterBN100,
-            BM10BN10ThreadClusterBM101,
-            BM10BN10ThreadClusterBN101,
+            BM10BN10ThreadClusterBM10Xs,
+            BM10BN10ThreadClusterBN10Xs,
            ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterArrangeOrder,
@@ -202,47 +200,36 @@ extern "C" __global__ void dynamic_convolution_forward_implicit_gemm_v6r1_nchw_k
             AGridMoveSliceWindowIteratorHacks,
             BGridMoveSliceWindowIteratorHacks>;

-    auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
-        GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(a_grid_desc_gk0_gm0_gm1_gk1);
-    auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
-        GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(b_grid_desc_gk0_gn0_gn1_gk1);
-    auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
-        GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
-            c_grid_desc_gm0_gm1_gn0_gn1);
-    auto c_grid_block_cluster_blockid_to_gm10_gn10 =
-        GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
-            c_grid_desc_gm0_gm1_gn0_gn1);
-
-    if(hipThreadIdx_x == 0)
+    if(get_block_1d_id() == 0 && get_thread_local_1d_id() == 0)
     {
-        *static_cast<decltype(a_grid_desc_gk0_gm0_gm10_gm11_gk1)*>(
-            p_a_grid_desc_gk0_gm0_gm10_gm11_gk1) = a_grid_desc_gk0_gm0_gm10_gm11_gk1;
-        *static_cast<decltype(b_grid_desc_gk0_gn0_gn10_gn11_gk1)*>(
-            p_b_grid_desc_gk0_gn0_gn10_gn11_gk1) = b_grid_desc_gk0_gn0_gn10_gn11_gk1;
-        *static_cast<decltype(c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1)*>(
-            p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1) = c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1;
-        *static_cast<decltype(c_grid_block_cluster_blockid_to_gm10_gn10)*>(
-            p_c_grid_block_cluster_blockid_to_gm10_gn10) =
-            c_grid_block_cluster_blockid_to_gm10_gn10;
-    };
-};
+        auto desc_tuple =
+            make_tuple(GridwiseContraction::MakeAGridDescriptor_GK0_GM0_GM10_GM11_GK1(
+                           a_grid_desc_gk0_gm0_gm1_gk1),
+                       GridwiseContraction::MakeBGridDescriptor_GK0_GN0_GN10_GN11_GK1(
+                           b_grid_desc_gk0_gn0_gn1_gk1),
+                       GridwiseContraction::MakeCGridDescriptor_GM10_BM0_BM1_GN10_BN0_BN1(
+                           c_grid_desc_gm0_gm1_gn0_gn1),
+                       GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
+                           c_grid_desc_gm0_gm1_gn0_gn1));
+
+        *static_cast<decltype(desc_tuple)*>(p_desc_tuple) = desc_tuple;
+    }
+}
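Read together, the prepare kernel now computes all four descriptors on the device and stores the whole bundle through `p_desc_tuple`, so the host never needs to name the heavily templated descriptor types. A minimal host-side sketch of the intended two-step launch; the workspace size, launch configuration, and argument variables (`N`, `C`, ..., `p_a_grid`, `grid_size`) are assumptions for illustration, and the real integration drives these kernels through MIOpen's online-compilation machinery rather than direct calls:

    // Sketch only: assumes both kernels in this file are visible to the caller.
    void* p_desc_tuple = nullptr;
    hipMalloc(&p_desc_tuple, 4096); // assumed upper bound on the tuple's size

    // Step 1: one workgroup is enough; a single thread writes the tuple.
    hipLaunchKernelGGL(dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw_prepare,
                       dim3(1), dim3(64), 0, 0,
                       N, C, Hi, Wi, K, Y, X,
                       ConvStrideH, ConvStrideW, ConvDilationH, ConvDilationW,
                       InLeftPadH, InLeftPadW, InRightPadH, InRightPadW,
                       p_desc_tuple);

    // Step 2: the main kernel reads the tuple back through its CONSTANT pointer;
    // the old-style cast mirrors the device-side cast used later in this file.
    hipLaunchKernelGGL(dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw,
                       dim3(grid_size), dim3(BlockSize), 0, 0,
                       p_a_grid, p_b_grid, p_c_grid,
                       (const void CONSTANT*)p_desc_tuple);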
extern "C" __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
dynamic_convolution_forward_implicit_gemm_v6r1_nchw_kcyx_nkhw(
|
||||
dynamic_convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const void __CONSTANT__* p_a_grid_desc_gk0_gm0_gm10_gm11_gk1,
|
||||
const void __CONSTANT__* p_b_grid_desc_gk0_gn0_gn10_gn11_gk1,
|
||||
const void __CONSTANT__* p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1,
|
||||
const void __CONSTANT__* p_c_grid_block_cluster_blockid_to_gm10_gn10)
|
||||
const void CONSTANT* p_desc_tuple)
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_n_c_hi_wi_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(256, 256, 28, 28));
|
||||
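`make_dynamic_naive_tensor_descriptor_packed_v2` derives strides from lengths under the packed row-major rule. A self-contained illustration of that rule in plain C++; this is not the CK implementation, just the arithmetic it encodes:

    #include <array>
    #include <cstddef>

    // Packed (row-major) strides: stride[i] = product of lengths[i+1 .. rank-1].
    // For lengths {N, C, Hi, Wi} = {256, 256, 28, 28} this gives
    // strides {256*28*28, 28*28, 28, 1} = {200704, 784, 28, 1}.
    template <std::size_t Rank>
    constexpr std::array<std::size_t, Rank>
    packed_strides(const std::array<std::size_t, Rank>& lengths)
    {
        std::array<std::size_t, Rank> strides{};
        std::size_t running = 1;
        for(std::size_t i = Rank; i > 0; --i)
        {
            strides[i - 1] = running;
            running *= lengths[i - 1];
        }
        return strides;
    }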
@@ -316,12 +303,12 @@ extern "C" __global__ void
         Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>;

     using GridwiseContraction =
-        GridwiseDynamicContraction_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
+        GridwiseDynamicContractionDlops_A_GK0_GM0_GM1_GK1_B_GK0_GN0_GN1_GK1_C_GM0_GM1_GN0_GN1<
             BlockSize,
             FloatAB,
             FloatAcc,
             FloatC,
-            InMemoryDataOperation::Set,
+            InMemoryDataOperationEnum_t::Set,
             AGridDesc_GK0_GM0_GM1_GK1,
             BGridDesc_GK0_GN0_GN1_GK1,
             CGridDesc_GM0_GM1_GN0_GN1,
@@ -331,10 +318,8 @@ extern "C" __global__ void
             BM1PerThreadBM11,
             BN1PerThreadBN11,
             BK0PerThread,
-            BM10BN10ThreadClusterBM100,
-            BM10BN10ThreadClusterBN100,
-            BM10BN10ThreadClusterBM101,
-            BM10BN10ThreadClusterBN101,
+            BM10BN10ThreadClusterBM10Xs,
+            BM10BN10ThreadClusterBN10Xs,
             ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
             ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
             ABlockTransferThreadClusterArrangeOrder,
@@ -371,18 +356,23 @@ extern "C" __global__ void
         decltype(GridwiseContraction::MakeCGridBlockCluster_BlockId_To_GM10_GN10(
             c_grid_desc_gm0_gm1_gn0_gn1));

-    const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 =
-        *reinterpret_cast<const AGridDesc_GK0_GM0_GM10_GM11_GK1*>(
-            (const void*)p_a_grid_desc_gk0_gm0_gm10_gm11_gk1);
-    const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 =
-        *reinterpret_cast<const BGridDesc_GK0_GN0_GN10_GN11_GK1*>(
-            (const void*)p_b_grid_desc_gk0_gn0_gn10_gn11_gk1);
-    const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1 =
-        *reinterpret_cast<const CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1*>(
-            (const void*)p_c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1);
-    const auto c_grid_block_cluster_blockid_to_gm10_gn10 =
-        *reinterpret_cast<const CGridBlockCluster_BlockId_To_GM10_GN10*>(
-            (const void*)p_c_grid_block_cluster_blockid_to_gm10_gn10);
+    using DescTuple = decltype(make_tuple(AGridDesc_GK0_GM0_GM10_GM11_GK1{},
+                                          BGridDesc_GK0_GN0_GN10_GN11_GK1{},
+                                          CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{},
+                                          CGridBlockCluster_BlockId_To_GM10_GN10{}));
+
+    const auto desc_tuple = *reinterpret_cast<const DescTuple*>(
+#pragma clang diagnostic push
+#pragma clang diagnostic ignored "-Wold-style-cast"
+        // TODO: how to cast?
+        (const void*)p_desc_tuple
+#pragma clang diagnostic pop
+    );
+
+    const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1         = desc_tuple[I0];
+    const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1         = desc_tuple[I1];
+    const auto c_grid_desc_gm10_bm0_bm1_gn10_bn0_bn1     = desc_tuple[I2];
+    const auto c_grid_block_cluster_blockid_to_gm10_gn10 = desc_tuple[I3];

     constexpr index_t shared_block_size =
         GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
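The pragmas above exist because stripping the `CONSTANT` qualifier off the kernel argument requires an old-style cast, which clang's `-Wold-style-cast` would otherwise flag. A sketch of what the macro plausibly expands to; the exact definition lives in the CK config headers, so treat this as an assumption:

    // Assumed definition: address space 4 is the AMDGPU constant address
    // space, so dereferences through such a pointer can be lowered to
    // scalar (uniform) loads instead of per-thread vector loads.
    #define CONSTANT __attribute__((address_space(4)))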