Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-11 17:00:18 +00:00)
DL GEMM fp32/fp16/int8 (#41)
* add a threadwise copy that copies a tensor in a single copy; add kpack to DL GEMM
* add kpack to fwd v4r5 nchw fp32
@@ -0,0 +1,292 @@
#ifndef CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP
#define CK_DRIVER_DYNAMIC_CONTRACTION_V1R2_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_contraction_v1r2.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGKGM0GM1GridDesc,
          typename BGKGN0GN1GridDesc,
          typename CGM0GM1GN0GN1GridDesc,
          index_t GM1PerBlockGM11,
          index_t GN1PerBlockGN11,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float
driver_dynamic_contraction_v1r2(const FloatAB* p_a_grid,
                                const FloatAB* p_b_grid,
                                FloatC* p_c_grid,
                                const AGKGM0GM1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
                                const BGKGN0GN1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
                                const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc,
                                AGridIteratorHacks,
                                BGridIteratorHacks,
                                CGridIteratorHacks,
                                AGridMoveSliceWindowIteratorHacks,
                                BGridMoveSliceWindowIteratorHacks,
                                index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseContraction = GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2<
        BlockSize, FloatAB, FloatAcc, FloatC, CGlobalMemoryDataOperation,
        AGKGM0GM1GridDesc, BGKGN0GN1GridDesc, CGM0GM1GN0GN1GridDesc,
        GM1PerBlockGM11, GN1PerBlockGN11, KPerBlock, M1PerThread, N1PerThread, KPerThread,
        M1N1ThreadClusterM10, M1N1ThreadClusterN10, M1N1ThreadClusterM11, M1N1ThreadClusterN11,
        ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
        ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
        ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
        ABlockTransferSrcVectorTensorContiguousDimOrder,
        BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
        BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
        BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
        BBlockTransferSrcVectorTensorContiguousDimOrder,
        CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        AGridIteratorHacks, BGridIteratorHacks, CGridIteratorHacks,
        AGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks>;

    const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);

    if(!GridwiseContraction::CheckValidity(
           a_gk0_gm0_gm1_gk1_grid_desc, b_gk0_gn0_gn1_gk1_grid_desc, c_gm0_gm1_gn0_gn1_grid_desc))
    {
        throw std::runtime_error(
            "wrong! GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2 has invalid "
            "setting");
    }

    const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc =
        GridwiseContraction::MakeAGK0GM0GM10GM11GK1GridDescriptor(a_gk0_gm0_gm1_gk1_grid_desc);
    const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc =
        GridwiseContraction::MakeBGK0GN0GN10GN11GK1GridDescriptor(b_gk0_gn0_gn1_gk1_grid_desc);

    using AGK0GM0GM10GM11GK1GridDesc = decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc);
    using BGK0GN0GN10GN11GK1GridDesc = decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc);

    // c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc
    const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc =
        GridwiseContraction::MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(c_gm0_gm1_gn0_gn1_grid_desc);

    using CGM10BM0BM1GN10BN0BN1GridDesc = decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc);

    // c_blockid_to_gm10_gn10_block_cluster_adaptor
    const auto c_blockid_to_gm10_gn10_block_cluster_adaptor =
        GridwiseContraction::MakeCBlockIdToGM10GN10BlockClusterAdaptor(c_gm0_gm1_gn0_gn1_grid_desc);

    using CBlockIdToGM10GN10BlockClusterAdaptor =
        decltype(c_blockid_to_gm10_gn10_block_cluster_adaptor);

    const index_t grid_size = GridwiseContraction::CalculateGridSize(c_gm0_gm1_gn0_gn1_grid_desc);

    const bool has_main_k_block_loop = GridwiseContraction::CalculateHasMainKBlockLoop(GK0);

    const bool has_double_tail_k_block_loop =
        GridwiseContraction::CalculateHasDoubleTailKBlockLoop(GK0);

    {
        std::cout << "a_gk0_gm0_gm10_gm11_gk1_grid_desc{"
                  << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0) << ", "
                  << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I1) << ", "
                  << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I2) << ", "
                  << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I3) << ", "
                  << a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;

        std::cout << "b_gk0_gn0_gn10_gn11_gk1_grid_desc{"
                  << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I0) << ", "
                  << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I1) << ", "
                  << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I2) << ", "
                  << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I3) << ", "
                  << b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetLength(I4) << "}" << std::endl;

        std::cout << "c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc{ "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I0) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I1) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I2) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I3) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I4) << ", "
                  << c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetLength(I5) << "}" << std::endl;
    }

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction, FloatAB, FloatC,
            remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
            remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            true, true>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                                          b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction, FloatAB, FloatC,
            remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
            remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            true, false>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                                          b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction, FloatAB, FloatC,
            remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
            remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            false, true>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                                          b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }
    else
    {
        const auto kernel = kernel_dynamic_contraction_v1r2<
            GridwiseContraction, FloatAB, FloatC,
            remove_reference_t<AGK0GM0GM10GM11GK1GridDesc>,
            remove_reference_t<BGK0GN0GN10GN11GK1GridDesc>,
            remove_reference_t<CGM10BM0BM1GN10BN0BN1GridDesc>,
            remove_reference_t<CBlockIdToGM10GN10BlockClusterAdaptor>,
            false, false>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                                          b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                                          c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                                          c_blockid_to_gm10_gn10_block_cluster_adaptor);
    }

    return ave_time;
}

} // namespace ck
#endif
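The four launch branches above exist because has_main_k_block_loop and has_double_tail_k_block_loop are runtime values, while the kernel takes them as compile-time template parameters; each branch instantiates one specialization. A minimal sketch of this runtime-to-compile-time dispatch pattern (names here are illustrative, not part of ck):

#include <cstdio>

// Stand-in for the gridwise kernel: the bools select which K-loop variants
// get compiled into this instantiation.
template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
void run_kernel(int gk0)
{
    std::printf("GK0 = %d, main loop = %d, double tail = %d\n",
                gk0, int(HasMainKBlockLoop), int(HasDoubleTailKBlockLoop));
}

// Runtime bools are folded into one of four compile-time specializations,
// mirroring the if / else-if chain in the driver above.
void dispatch(int gk0, bool has_main, bool has_double_tail)
{
    if(has_main && has_double_tail)
        run_kernel<true, true>(gk0);
    else if(has_main && !has_double_tail)
        run_kernel<true, false>(gk0);
    else if(!has_main && has_double_tail)
        run_kernel<false, true>(gk0);
    else
        run_kernel<false, false>(gk0);
}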
416  composable_kernel/include/driver/driver_dynamic_gemm_v1r3.hpp (new file)
@@ -0,0 +1,416 @@
#ifndef CK_DRIVER_DYNAMIC_GEMM_V1R3_HPP
#define CK_DRIVER_DYNAMIC_GEMM_V1R3_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "gridwise_dynamic_gemm_v1r3.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t M1PerThread,
          index_t N1PerThread,
          index_t KPerThread,
          index_t M1N1ThreadClusterM10,
          index_t M1N1ThreadClusterN10,
          index_t M1N1ThreadClusterM11,
          index_t M1N1ThreadClusterN11,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
__host__ float driver_dynamic_gemm_v1r3(const FloatAB* p_a_grid,
                                        const FloatAB* p_b_grid,
                                        FloatC* p_c_grid,
                                        const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                                        const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                                        const CMNGridDesc& c_m_n_grid_desc,
                                        AGridIteratorHacks,
                                        BGridIteratorHacks,
                                        CGridIteratorHacks,
                                        AGridMoveSliceWindowIteratorHacks,
                                        BGridMoveSliceWindowIteratorHacks,
                                        index_t nrepeat)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};

    // GEMM
    using GridwiseGemm = GridwiseDynamicGemm_km_kn_mn_v1r3<
        BlockSize, FloatAB, FloatAcc, FloatC, CGlobalMemoryDataOperation,
        AK0MK1GridDesc, BK0NK1GridDesc, CMNGridDesc,
        MPerBlock, NPerBlock, KPerBlock, M1PerThread, N1PerThread, KPerThread,
        M1N1ThreadClusterM10, M1N1ThreadClusterN10, M1N1ThreadClusterM11, M1N1ThreadClusterN11,
        ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
        ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
        ABlockTransferSrcVectorTensorContiguousDimOrder,
        ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
        BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
        BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
        BBlockTransferSrcVectorTensorContiguousDimOrder,
        BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
        CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        AGridIteratorHacks, BGridIteratorHacks, CGridIteratorHacks,
        AGridMoveSliceWindowIteratorHacks, BGridMoveSliceWindowIteratorHacks>;

    const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
    const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
    const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);

    if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc))
    {
        throw std::runtime_error("wrong! GridwiseDynamicGemm_km_kn_mn_v1r3 has invalid setting");
    }

    const auto a_k0_m0_m1_k1_grid_desc =
        GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_k0_m_k1_grid_desc);
    const auto b_k0_n0_n1_k1_grid_desc =
        GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_k0_n_k1_grid_desc);

    using AK0M0M1K1GridDesc = decltype(a_k0_m0_m1_k1_grid_desc);
    using BK0N0N1K1GridDesc = decltype(b_k0_n0_n1_k1_grid_desc);

    // c_m0_m10_m11_n0_n10_n11_grid_desc
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc);

    using CM0M10M11N0N10N11GridDesc = decltype(c_m0_m10_m11_n0_n10_n11_grid_desc);

    // c_blockid_to_m0_n0_block_cluster_adaptor
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_m_n_grid_desc);

    using CBlockIdToM0N0BlockClusterAdaptor = decltype(c_blockid_to_m0_n0_block_cluster_adaptor);

    const index_t grid_size = GridwiseGemm::CalculateGridSize(M, N);

    const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);

    const bool has_double_tail_k_block_loop = GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);

    {
        std::cout << "a_k0_m0_m1_k1_grid_desc{" << a_k0_m0_m1_k1_grid_desc.GetLength(I0) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I1) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I2) << ", "
                  << a_k0_m0_m1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "b_k0_n0_n1_k1_grid_desc{" << b_k0_n0_n1_k1_grid_desc.GetLength(I0) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I1) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I2) << ", "
                  << b_k0_n0_n1_k1_grid_desc.GetLength(I3) << "}" << std::endl;

        std::cout << "c_m0_m10_m11_n0_n10_n11_grid_desc{ "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I0) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I1) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I2) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I3) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I4) << ", "
                  << c_m0_m10_m11_n0_n10_n11_grid_desc.GetLength(I5) << "}" << std::endl;
    }

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            true, true>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            true, false>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            false, true>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            false, false>;

        ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
                                          p_a_grid, p_b_grid, p_c_grid,
                                          a_k0_m0_m1_k1_grid_desc,
                                          b_k0_n0_n1_k1_grid_desc,
                                          c_m0_m10_m11_n0_n10_n11_grid_desc,
                                          c_blockid_to_m0_n0_block_cluster_adaptor);
    }

    return ave_time;
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
    DeviceMem a_k0_m0_m1_k1_grid_desc_dev_buf(sizeof(AK0M0M1K1GridDesc));
    DeviceMem b_k0_n0_n1_k1_grid_desc_dev_buf(sizeof(BK0N0N1K1GridDesc));
    DeviceMem c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf(sizeof(CM0M10M11N0N10N11GridDesc));
    DeviceMem c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf(
        sizeof(CBlockIdToM0N0BlockClusterAdaptor));

    a_k0_m0_m1_k1_grid_desc_dev_buf.ToDevice(&a_k0_m0_m1_k1_grid_desc);
    b_k0_n0_n1_k1_grid_desc_dev_buf.ToDevice(&b_k0_n0_n1_k1_grid_desc);
    c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.ToDevice(&c_m0_m10_m11_n0_n10_n11_grid_desc);
    c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.ToDevice(
        &c_blockid_to_m0_n0_block_cluster_adaptor);

    float ave_time = 0;

    if(has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            true, true>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
            p_a_grid, p_b_grid, p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            true, false>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
            p_a_grid, p_b_grid, p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            false, true>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
            p_a_grid, p_b_grid, p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }
    else
    {
        const auto kernel = kernel_dynamic_gemm_v1r3<
            GridwiseGemm, FloatAB, FloatC,
            remove_reference_t<AK0M0M1K1GridDesc>,
            remove_reference_t<BK0N0N1K1GridDesc>,
            remove_reference_t<CM0M10M11N0N10N11GridDesc>,
            remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>,
            false, false>;

        ave_time = launch_and_time_kernel(
            kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, 0,
            p_a_grid, p_b_grid, p_c_grid,
            (void __CONSTANT__*)a_k0_m0_m1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)b_k0_n0_n1_k1_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_m0_m10_m11_n0_n10_n11_grid_desc_dev_buf.GetDeviceBuffer(),
            (void __CONSTANT__*)c_blockid_to_m0_n0_block_cluster_adaptor_dev_buf.GetDeviceBuffer());
    }

    return ave_time;
#endif
}

} // namespace ck
#endif
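The CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER path above copies each descriptor into a device buffer once and hands the kernel only a pointer, instead of passing the descriptor by value through the kernel argument buffer. A minimal HIP sketch of that pattern; TileDesc, read_desc, and the shapes are illustrative stand-ins, not ck types:

#include <hip/hip_runtime.h>

// POD stand-in for a tensor descriptor
struct TileDesc
{
    int lengths[4];
    int strides[4];
};

__global__ void read_desc(const TileDesc* __restrict__ p_desc, int* p_out)
{
    // all threads read the same descriptor from device memory
    p_out[threadIdx.x] = p_desc->lengths[0];
}

void launch(const TileDesc& host_desc, int* p_out_dev)
{
    TileDesc* p_desc_dev = nullptr;
    hipMalloc(&p_desc_dev, sizeof(TileDesc));
    // one host-to-device copy, analogous to DeviceMem::ToDevice above
    hipMemcpy(p_desc_dev, &host_desc, sizeof(TileDesc), hipMemcpyHostToDevice);
    hipLaunchKernelGGL(read_desc, dim3(1), dim3(64), 0, 0, p_desc_dev, p_out_dev);
    hipFree(p_desc_dev);
}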
@@ -12,7 +12,7 @@ namespace ck {
 // C: out
 // GemmM = N * Ho * Wo
 // GemmN = K
-// GemmK = C * Y * X
+// GemmK = Y * X * C
 template <typename... In,
           typename... Wei,
           typename... Out,
@@ -0,0 +1,132 @@
#ifndef CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP
#define CK_TRANSFORM_FORWARD_CONVOLUTION_INTO_CONTRACTION_V4R5R2_NCHW_KCYX_NKHW_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// GemmM = K
// GemmN = N * Ho * Wo
// GemmK = C * Y * X
template <typename... Wei,
          typename... In,
          typename... Out,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads,
          index_t N0Value,
          index_t C0Value>
__host__ __device__ constexpr auto
transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(
    const DynamicTensorDescriptor<Wei...>& wei_k_c_y_x_grid_desc,
    const DynamicTensorDescriptor<In...>& in_n_c_hi_wi_grid_desc,
    const DynamicTensorDescriptor<Out...>& out_n_k_ho_wo_grid_desc,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    Number<N0Value>,
    Number<C0Value>)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0);
    const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1);
    const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1);

    const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2);
    const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3);

    const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2);
    const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3);

    const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2);
    const auto X = wei_k_c_y_x_grid_desc.GetLength(I3);

    const auto ConvStrideH = conv_strides[I0];
    const auto ConvStrideW = conv_strides[I1];

    const auto ConvDilationH = conv_dilations[I0];
    const auto ConvDilationW = conv_dilations[I1];

    const auto InLeftPadH = in_left_pads[I0];
    const auto InLeftPadW = in_left_pads[I1];

    const auto InRightPadH = in_right_pads[I0];
    const auto InRightPadW = in_right_pads[I1];

    constexpr auto N0 = Number<N0Value>{};
    constexpr auto C0 = Number<C0Value>{};

    const auto N1 = N / N0;
    const auto C1 = C / C0;

    // weight tensor
    const auto wei_gk0_gm0_gm1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(K, C * Y * X)),
        make_tuple(make_unmerge_transform(make_tuple(I1, K)),
                   make_unmerge_transform(make_tuple(C0, C1 * Y * X))),
        make_tuple(Sequence<0>{}, Sequence<1>{}),
        make_tuple(Sequence<1, 2>{}, Sequence<3, 0>{}));

    // input tensor
    const auto in_n_c_hip_wip_grid_desc = transform_dynamic_tensor_descriptor(
        in_n_c_hi_wi_grid_desc,
        make_tuple(make_pass_through_transform(N),
                   make_pass_through_transform(C),
                   make_pad_transform(Hi, InLeftPadH, InRightPadH),
                   make_pad_transform(Wi, InLeftPadW, InRightPadW)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    const auto in_n0_n1_c0_c1_y_ho_x_wo_grid_desc = transform_dynamic_tensor_descriptor(
        in_n_c_hip_wip_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(N0, N1)),
                   make_unmerge_transform(make_tuple(C0, C1)),
                   make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
                   make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));

    const auto in_gk0_gn0_gn1_gk1_grid_desc = transform_dynamic_tensor_descriptor(
        in_n0_n1_c0_c1_y_ho_x_wo_grid_desc,
        make_tuple(make_merge_transform(make_tuple(C1, Y, X)),
                   make_pass_through_transform(N0),
                   make_merge_transform(make_tuple(N1, Ho, Wo)),
                   make_pass_through_transform(C0)),
        make_tuple(Sequence<3, 4, 6>{}, Sequence<0>{}, Sequence<1, 5, 7>{}, Sequence<2>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    // output tensor
    const auto out_n_k_howo_grid_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(make_tuple(N, K, Ho * Wo));

    const auto out_n0_n1_1_k_howo_grid_desc = transform_dynamic_tensor_descriptor(
        out_n_k_howo_grid_desc,
        make_tuple(make_unmerge_transform(make_tuple(Number<N0>{}, N1)),
                   make_unmerge_transform(make_tuple(I1, K)),
                   make_pass_through_transform(Ho * Wo)),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
        make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}, Sequence<4>{}));

    const auto out_gm0_gm1_gn0_gn1_grid_desc = transform_dynamic_tensor_descriptor(
        out_n0_n1_1_k_howo_grid_desc,
        make_tuple(make_pass_through_transform(I1),
                   make_pass_through_transform(K),
                   make_pass_through_transform(Number<N0>{}),
                   make_merge_transform_v2_magic_division(make_tuple(N1, Ho * Wo))),
        make_tuple(Sequence<2>{}, Sequence<3>{}, Sequence<0>{}, Sequence<1, 4>{}),
        make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

    return make_tuple(
        wei_gk0_gm0_gm1_gk1_grid_desc, in_gk0_gn0_gn1_gk1_grid_desc, out_gm0_gm1_gn0_gn1_grid_desc);
}

} // namespace ck
#endif
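Per the comments at the top of this file, the transform views the convolution as a contraction with GemmM = K, GemmN = N * Ho * Wo, and GemmK = C * Y * X, where the weight unmerge further splits GemmK into GK0 = C1 * Y * X and GK1 = C0, and N into N0 * N1. A small worked example with assumed (hypothetical) shapes:

#include <cstdio>

int main()
{
    // assumed convolution shape: NCHW input, KCYX weight, NKHW output
    const int N = 8, C = 64, Hi = 30, Wi = 30;
    const int K = 128, Y = 3, X = 3;
    const int Ho = Hi - Y + 1, Wo = Wi - X + 1; // stride 1, no pad, no dilation
    const int N0 = 2, C0 = 4;                   // the Number<N0Value>, Number<C0Value> splits
    const int N1 = N / N0, C1 = C / C0;

    std::printf("GemmM = K = %d\n", K);
    std::printf("GemmN = N * Ho * Wo = %d (GN0 = N0 = %d, GN1 = N1*Ho*Wo = %d)\n",
                N * Ho * Wo, N0, N1 * Ho * Wo);
    std::printf("GemmK = C * Y * X = %d (GK0 = C1*Y*X = %d, GK1 = C0 = %d)\n",
                C * Y * X, C1 * Y * X, C0);
    return 0;
}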
@@ -0,0 +1,158 @@
#ifndef CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_BLOCKWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "cluster_descriptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"

namespace ck {

// This version does the following things to avoid the scratch memory issue:
// 1. Use StaticallyIndexedArray instead of a C array for the thread buffer
// 2. ThreadwiseDynamicTensorSliceTransfer_v3 does not keep a reference to the tensor descriptor
// 3. ThreadwiseDynamicTensorSliceTransfer_v3::Run() does not construct a new tensor coordinate
template <index_t BlockSize,
          InMemoryDataOperation DstInMemOp,
          typename BlockSliceLengths,
          typename ThreadSliceLengths,
          typename ThreadClusterLengths,
          typename ThreadClusterArrangeOrder,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          typename SrcVectorTensorLengths,
          typename DstVectorTensorLengths,
          typename SrcVectorTensorContiguousDimOrder,
          typename DstVectorTensorContiguousDimOrder,
          bool ThreadTransferSrcResetCoordinateAfterRun,
          bool ThreadTransferDstResetCoordinateAfterRun>
struct BlockwiseDynamicTensorSliceTransfer_v4r1
{
    static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();

    using Index = MultiIndex<nDim>;

    __device__ constexpr BlockwiseDynamicTensorSliceTransfer_v4r1(
        const SrcDesc& src_desc,
        const Index& src_block_slice_origin,
        const DstDesc& dst_desc,
        const Index& dst_block_slice_origin)
        : threadwise_transfer_(
              src_desc, make_zero_multi_index<nDim>(), dst_desc, make_zero_multi_index<nDim>())
    {
        static_assert(nDim == remove_reference_t<remove_cv_t<SrcDesc>>::GetNumOfDimension() &&
                          nDim == remove_reference_t<remove_cv_t<DstDesc>>::GetNumOfDimension() &&
                          nDim == BlockSliceLengths::Size() && nDim == ThreadSliceLengths::Size() &&
                          nDim == ThreadClusterLengths::Size() &&
                          nDim == ThreadClusterArrangeOrder::Size() &&
                          nDim == SrcDimAccessOrder::Size() && nDim == DstDimAccessOrder::Size(),
                      "wrong! nDim not consistent");

        static_assert(
            is_same<BlockSliceLengths, decltype(ThreadSliceLengths{} * ThreadClusterLengths{})>{},
            "wrong! threads should be mapped to cover entire slicing window");

        static_assert(BlockSize >= thread_cluster_desc_.GetElementSize(),
                      "wrong! BlockSize too small");

        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(
                make_multi_index(get_thread_local_1d_id()));

            const auto thread_data_idx_begin = thread_cluster_idx * ThreadSliceLengths{};

            threadwise_transfer_.SetSrcSliceOrigin(src_desc,
                                                   src_block_slice_origin + thread_data_idx_begin);
            threadwise_transfer_.SetDstSliceOrigin(dst_desc,
                                                   dst_block_slice_origin + thread_data_idx_begin);
        }
    }

    template <typename SrcBuffer, typename SrcIteratorHacks>
    __device__ void RunRead(const SrcDesc& src_desc,
                            const SrcBuffer& src_buf,
                            const SrcIteratorHacks& src_iterator_hacks)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunRead(src_desc, src_buf, src_iterator_hacks);
        }
    }

    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.RunWrite(dst_desc, dst_buf);
        }
    }

    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(src_desc, step);
        }
    }

    // SrcMoveSliceWindowIteratorHack controls index calculation when moving the slice window
    template <typename SrcMoveSliceWindowIteratorHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& step,
                       const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveSrcSliceWindow(
                src_desc, step, src_move_slice_window_iterator_hack);
        }
    }

    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc, const Index& step)
    {
        if(BlockSize == thread_cluster_desc_.GetElementSize() or
           get_thread_local_1d_id() < thread_cluster_desc_.GetElementSize())
        {
            threadwise_transfer_.MoveDstSliceWindow(dst_desc, step);
        }
    }

    private:
    static constexpr auto thread_cluster_desc_ =
        make_cluster_descriptor_v2(ThreadClusterLengths{}, ThreadClusterArrangeOrder{});

    using ThreadwiseTransfer = ThreadwiseDynamicTensorSliceTransfer_v3r1<
        ThreadSliceLengths, DstInMemOp, SrcData, DstData, SrcDesc, DstDesc,
        SrcDimAccessOrder, DstDimAccessOrder,
        SrcVectorTensorLengths, DstVectorTensorLengths,
        SrcVectorTensorContiguousDimOrder, DstVectorTensorContiguousDimOrder,
        ThreadTransferSrcResetCoordinateAfterRun, ThreadTransferDstResetCoordinateAfterRun>;

    ThreadwiseTransfer threadwise_transfer_;
};

} // namespace ck
#endif
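In a pipelined GEMM, this blockwise transfer is typically driven as read-global, write-LDS, advance-window, repeat. A one-dimensional CPU analogue of that RunRead/RunWrite/MoveSrcSliceWindow protocol (all types here are illustrative, not the ck templates):

#include <cstddef>
#include <vector>

// CPU stand-in for the transfer object: stage a tile out of src, emit it to
// dst, then slide the source window forward.
struct SliceTransfer
{
    std::size_t src_origin = 0;

    void RunRead(const std::vector<float>& src, std::vector<float>& staging)
    {
        for(std::size_t i = 0; i < staging.size(); ++i)
            staging[i] = src[src_origin + i];
    }

    void RunWrite(const std::vector<float>& staging, std::vector<float>& dst)
    {
        for(std::size_t i = 0; i < staging.size(); ++i)
            dst[i] = staging[i];
    }

    void MoveSrcSliceWindow(std::size_t step) { src_origin += step; }
};

int main()
{
    const std::size_t K = 64, KPerBlock = 16;
    std::vector<float> src(K, 1.0f), lds(KPerBlock), staging(KPerBlock);

    SliceTransfer transfer;
    for(std::size_t k = 0; k < K; k += KPerBlock)
    {
        transfer.RunRead(src, staging);  // global -> thread registers
        transfer.RunWrite(staging, lds); // thread registers -> LDS
        transfer.MoveSrcSliceWindow(KPerBlock);
        // ... a blockwise GEMM would consume `lds` here ...
    }
    return 0;
}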
@@ -0,0 +1,398 @@
#ifndef CK_BLOCKWISE_GEMM_V2R3_HPP
#define CK_BLOCKWISE_GEMM_V2R3_HPP

#include "common_header.hpp"
#include "tensor_adaptor.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_gemm_v2.hpp"

namespace ck {

// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visible to the whole block; C is distributed among the threads
// Assume:
//   1. A:
//      1. AK0MK1BlockDesc is known at compile-time
//      2. ABlockBuffer is DynamicBuffer
//   2. B:
//      1. BK0NK1BlockDesc is known at compile-time
//      2. BBlockBuffer is DynamicBuffer
//   3. C:
//      1. CM0M1N0N1ThreadDesc is known at compile-time
//      2. CThreadBuffer is StaticBuffer
// Also assume:
//   M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize,
          typename FloatA,
          typename FloatB,
          typename FloatC,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t M1PerThreadM11,
          index_t N1PerThreadN11,
          index_t KPerThread,
          index_t M1N1ThreadClusterM100,
          index_t M1N1ThreadClusterN100,
          index_t M1N1ThreadClusterM101,
          index_t M1N1ThreadClusterN101,
          index_t AThreadCopyScalarPerVector_M11,
          index_t BThreadCopyScalarPerVector_N11,
          typename std::enable_if<AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                                      BK0NK1BlockDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2
{
    using AIndex = MultiIndex<3>;
    using BIndex = MultiIndex<3>;
    using CIndex = MultiIndex<4>;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t M  = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t N  = BK0NK1BlockDesc{}.GetLength(I1);

    static constexpr index_t M100 = M1N1ThreadClusterM100;
    static constexpr index_t N100 = M1N1ThreadClusterN100;

    static constexpr index_t M101 = M1N1ThreadClusterM101;
    static constexpr index_t N101 = M1N1ThreadClusterN101;

    static constexpr index_t M11 = M1PerThreadM11;
    static constexpr index_t N11 = N1PerThreadN11;

    static constexpr index_t M1 = M1N1ThreadClusterM100 * M1N1ThreadClusterM101 * M1PerThreadM11;
    static constexpr index_t N1 = M1N1ThreadClusterN100 * M1N1ThreadClusterN101 * N1PerThreadN11;

    static constexpr index_t M0 = M / M1;
    static constexpr index_t N0 = N / N1;

    __host__ __device__ static constexpr auto
    MakeAK0M0M1K1BlockDescriptor(const AK0MK1BlockDesc& a_k0_m_k1_block_desc)
    {
        const auto a_k0_m0_m1_k1_block_desc = transform_dynamic_tensor_descriptor(
            a_k0_m_k1_block_desc,
            make_tuple(make_pass_through_transform(Number<K0>{}),
                       make_unmerge_transform(make_tuple(Number<M0>{}, Number<M1>{})),
                       make_pass_through_transform(Number<K1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_k0_m0_m1_k1_block_desc;
    }

    __host__ __device__ static constexpr auto
    MakeBK0N0N1K1BlockDescriptor(const BK0NK1BlockDesc& b_k0_n_k1_block_desc)
    {
        const auto b_k0_n0_n1_k1_block_desc = transform_dynamic_tensor_descriptor(
            b_k0_n_k1_block_desc,
            make_tuple(make_pass_through_transform(Number<K0>{}),
                       make_unmerge_transform(make_tuple(Number<N0>{}, Number<N1>{})),
                       make_pass_through_transform(Number<K1>{})),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_k0_n0_n1_k1_block_desc;
    }

    __host__ __device__ static constexpr auto MakeCM0M100M101M11N0N100N101N11ToMNBlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M, N]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_unmerge_transform(make_tuple(
                               Number<M0>{}, Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_unmerge_transform(make_tuple(
                               Number<N0>{}, Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2, 3>{}, Sequence<4, 5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m_n_block_adaptor;
    }

    __host__ __device__ static constexpr auto
    MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor()
    {
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        // lower: [M0, M1, N0, N1]
        constexpr auto c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor =
            make_single_stage_tensor_adaptor(
                make_tuple(make_pass_through_transform(Number<M0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<M100>{}, Number<M101>{}, Number<M11>{})),
                           make_pass_through_transform(Number<N0>{}),
                           make_unmerge_transform(
                               make_tuple(Number<N100>{}, Number<N101>{}, Number<N11>{}))),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}, Sequence<5, 6, 7>{}));

        return c_m0_m100_m101_m11_n0_n100_n101_n11_to_m0_m1_n0_n1_block_adaptor;
    }

    __host__ __device__ static constexpr auto GetCM0M1N0N1ThreadTensorLengths()
    {
        return Sequence<M0, M11, N0, N11>{};
    }

    static constexpr auto a_k0_m0_m1_k1_block_desc_ =
        MakeAK0M0M1K1BlockDescriptor(AK0MK1BlockDesc{});
    static constexpr auto b_k0_n0_n1_k1_block_desc_ =
        MakeBK0N0N1K1BlockDescriptor(BK0NK1BlockDesc{});

    public:
    __device__ BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2()
        : c_thread_origin_data_idx_{CalculateCM0M1N0N1ThreadOriginOnBlock(
              get_thread_local_1d_id())},
          a_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I0], c_thread_origin_data_idx_[I1], 0)},
          b_thread_copy_{
              make_tuple(0, c_thread_origin_data_idx_[I2], c_thread_origin_data_idx_[I3], 0)}
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == M101 * M100 * N101 * N100,
                      "wrong! blocksize and cluster size not consistent");

        static_assert(M % M1 == 0 && N % N1 == 0, "wrong!");

        static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0),
                      "wrong! K dimension not consistent");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2, "wrong");
    }

    __device__ static CIndex CalculateCM0M1N0N1ThreadOriginOnBlock(index_t thread_id)
    {
        // lower: [M0, M1, N0, N1]
        // upper: [M0, M100, M101, M11, N0, N100, N101, N11]
        constexpr auto adaptor0 = MakeCM0M100M101M11N0N100N101N11ToM0M1N0N1BlockAdaptor();

        // lower: [M0, M100, M101, M11, N0, N100, N101, N11]
        // upper: [Tid, M0, M11, N0, N11]
        constexpr auto adaptor1 = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(M100, N100, M101, N101)),
                       make_pass_through_transform(M0),
                       make_pass_through_transform(M11),
                       make_pass_through_transform(N0),
                       make_pass_through_transform(N11)),
            make_tuple(
                Sequence<1, 5, 2, 6>{}, Sequence<0>{}, Sequence<3>{}, Sequence<4>{}, Sequence<7>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));

        constexpr auto adaptor = chain_tensor_adaptors(adaptor0, adaptor1);

        return adaptor.CalculateBottomIndex(make_multi_index(get_thread_local_1d_id(), 0, 0, 0, 0));
    }

    template <typename CM0M1N0N1ThreadDesc,
              typename ABlockBuffer,
              typename BBlockBuffer,
              typename CThreadBuffer>
    __device__ void Run(const CM0M1N0N1ThreadDesc& c_m0_m1_n0_n1_thread_desc,
                        const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        static_assert(CM0M1N0N1ThreadDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: remove this restriction
        static_assert(M0 == 2 && N0 == 2 && CM0M1N0N1ThreadDesc{}.GetLength(I0) == M0 &&
                          CM0M1N0N1ThreadDesc{}.GetLength(I2) == N0,
                      "wrong");

        auto a_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatA>(
            a_k0_m0_m1_k1_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatB>(
            b_k0_n0_n1_k1_thread_desc_.GetElementSpaceSize());

        constexpr auto threadwise_gemm =
            ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1<FloatA,
                                                      FloatB,
                                                      FloatC,
                                                      decltype(a_k0_m0_m1_k1_thread_desc_),
                                                      decltype(b_k0_n0_n1_k1_thread_desc_),
                                                      CM0M1N0N1ThreadDesc,
                                                      Sequence<KPerThread, K1>,
                                                      Sequence<1, M1PerThreadM11>,
                                                      Sequence<1, N1PerThreadN11>>{};

        // read A_sub_0
        a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, make_tuple(I0, I0, I0, I0), a_block_buf,
                           a_k0_m0_m1_k1_thread_desc_, make_tuple(I0, I0, I0, I0), a_thread_buf);

        // read B_sub_0
        b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, make_tuple(I0, I0, I0, I0), b_block_buf,
                           b_k0_n0_n1_k1_thread_desc_, make_tuple(I0, I0, I0, I0), b_thread_buf);

        // read B_sub_1
        b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, make_tuple(I0, I1, I0, I0), b_block_buf,
                           b_k0_n0_n1_k1_thread_desc_, make_tuple(I0, I1, I0, I0), b_thread_buf);

        // read A_sub_1
        a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, make_tuple(I0, I1, I0, I0), a_block_buf,
                           a_k0_m0_m1_k1_thread_desc_, make_tuple(I0, I1, I0, I0), a_thread_buf);

        // C_sub_00 += transpose(A_sub_0) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0, I0),
                            b_thread_buf, make_tuple(I0, I0, I0, I0),
                            c_thread_buf, make_tuple(I0, I0, I0, I0));

        // C_sub_01 += transpose(A_sub_0) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0, I0),
                            b_thread_buf, make_tuple(I0, I1, I0, I0),
                            c_thread_buf, make_tuple(I0, I0, I1, I0));

        // loop over rest of k
        static_for<KPerThread, K0, KPerThread>{}([&](auto k) {
            // read A_sub_0
            a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, make_tuple(k, I0, I0, I0), a_block_buf,
                               a_k0_m0_m1_k1_thread_desc_, make_tuple(I0, I0, I0, I0),
                               a_thread_buf);

            // C_sub_10 += transpose(A_sub_1) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0, I0),
                                b_thread_buf, make_tuple(I0, I0, I0, I0),
                                c_thread_buf, make_tuple(I1, I0, I0, I0));

            // read B_sub_0
            b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, make_tuple(k, I0, I0, I0), b_block_buf,
                               b_k0_n0_n1_k1_thread_desc_, make_tuple(I0, I0, I0, I0),
                               b_thread_buf);

            // C_sub_11 += transpose(A_sub_1) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0, I0),
                                b_thread_buf, make_tuple(I0, I1, I0, I0),
                                c_thread_buf, make_tuple(I1, I0, I1, I0));

            // read B_sub_1
            b_thread_copy_.Run(b_k0_n0_n1_k1_block_desc_, make_tuple(k, I1, I0, I0), b_block_buf,
                               b_k0_n0_n1_k1_thread_desc_, make_tuple(I0, I1, I0, I0),
                               b_thread_buf);

            // read A_sub_1
            a_thread_copy_.Run(a_k0_m0_m1_k1_block_desc_, make_tuple(k, I1, I0, I0), a_block_buf,
                               a_k0_m0_m1_k1_thread_desc_, make_tuple(I0, I1, I0, I0),
                               a_thread_buf);

            // C_sub_00 += transpose(A_sub_0) * B_sub_0
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I0, I0, I0),
                                c_thread_buf, make_tuple(I0, I0, I0, I0));

            // C_sub_01 += transpose(A_sub_0) * B_sub_1
            threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I0, I0, I0),
                                b_thread_buf, make_tuple(I0, I1, I0, I0),
                                c_thread_buf, make_tuple(I0, I0, I1, I0));
        });

        // C_sub_10 += transpose(A_sub_1) * B_sub_0
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0, I0),
                            b_thread_buf, make_tuple(I0, I0, I0, I0),
                            c_thread_buf, make_tuple(I1, I0, I0, I0));

        // C_sub_11 += transpose(A_sub_1) * B_sub_1
        threadwise_gemm.Run(a_thread_buf, make_tuple(I0, I1, I0, I0),
                            b_thread_buf, make_tuple(I0, I1, I0, I0),
                            c_thread_buf, make_tuple(I1, I0, I1, I0));
    }

    private:
    // A[K0, M0, M1, K1]
    static constexpr auto a_k0_m0_m1_k1_thread_desc_ =
        make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<KPerThread>{}, Number<M0>{}, Number<M1PerThreadM11>{}, Number<K1>{}));

    // B[K0, N0, N1, K1]
    static constexpr auto b_k0_n0_n1_k1_thread_desc_ =
        make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<KPerThread>{}, Number<N0>{}, Number<N1PerThreadN11>{}, Number<K1>{}));

    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
        FloatA,
        FloatA,
        decltype(a_k0_m0_m1_k1_block_desc_),
        decltype(a_k0_m0_m1_k1_thread_desc_),
        Sequence<KPerThread, 1, M1PerThreadM11, K1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                        // DimAccessOrder
        Sequence<1, 1, M1PerThreadM11, K1>,          // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                       // SrcVectorTensorContiguousDimOrder

    using BThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4r1<
        FloatB,
        FloatB,
        decltype(b_k0_n0_n1_k1_block_desc_),
        decltype(b_k0_n0_n1_k1_thread_desc_),
        Sequence<KPerThread, 1, N1PerThreadN11, K1>, // SliceLengths
        Sequence<0, 1, 2, 3>,                        // DimAccessOrder
        Sequence<1, 1, N1PerThreadN11, K1>,          // SrcVectorTensorLengths
        Sequence<0, 1, 2, 3>>;                       // SrcVectorTensorContiguousDimOrder

    CIndex c_thread_origin_data_idx_;

    AThreadCopy a_thread_copy_;
    BThreadCopy b_thread_copy_;
};

} // namespace ck
#endif
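Run() above implements the 2x2 "ABBA" schedule: the four C sub-tiles are updated in an order that interleaves each A/B sub-tile read with an FMA on data already in registers. A scalar CPU sketch of the same schedule, with K x 2 operands and a 2x2 accumulator (purely illustrative):

#include <vector>

// C[2][2] += transpose(A) * B for A, B of shape K x 2, with the reads of the
// next A0/B0/B1/A1 interleaved between the four sub-products, as in Run().
void gemm_2x2_pipeline(const std::vector<float>& A, // A[k * 2 + m]
                       const std::vector<float>& B, // B[k * 2 + n]
                       int K,
                       float C[2][2])
{
    // prologue: read A_sub_0, B_sub_0, B_sub_1, A_sub_1; compute C00, C01
    float a0 = A[0], b0 = B[0], b1 = B[1], a1 = A[1];
    C[0][0] += a0 * b0;
    C[0][1] += a0 * b1;

    for(int k = 1; k < K; ++k)
    {
        float a0_next = A[k * 2 + 0]; // read next A_sub_0
        C[1][0] += a1 * b0;           // C10 on current registers
        float b0_next = B[k * 2 + 0]; // read next B_sub_0
        C[1][1] += a1 * b1;           // C11 on current registers
        b1 = B[k * 2 + 1];            // read next B_sub_1
        a1 = A[k * 2 + 1];            // read next A_sub_1
        a0 = a0_next;
        b0 = b0_next;
        C[0][0] += a0 * b0;           // C00 for step k
        C[0][1] += a0 * b1;           // C01 for step k
    }

    // epilogue: C10, C11 for the last step
    C[1][0] += a1 * b0;
    C[1][1] += a1 * b1;
}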
@@ -0,0 +1,675 @@
|
||||
#ifndef CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP
|
||||
#define CK_GRIDWISE_DYNAMIC_CONTRACTION_V1R2_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dynamic_multi_index_transform_helper.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor_helper.hpp"
|
||||
#include "blockwise_gemm_v2r2.hpp"
|
||||
#include "blockwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_transfer.hpp"
|
||||
#include "threadwise_dynamic_tensor_slice_set.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename GridwiseContraction,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AGK0GM0GM10GM11GK1GridDesc,
|
||||
typename BGK0GN0GN10GN11GK1GridDesc,
|
||||
typename CGM10BM0BM1GN10BN0BN1GridDesc,
|
||||
typename CBlockIdToGM10GN10BlockClusterAdaptor,
|
||||
bool HasMainKBlockLoop,
|
||||
bool HasDoubleTailKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_dynamic_contraction_v1r2(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AGK0GM0GM10GM11GK1GridDesc a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
const BGK0GN0GN10GN11GK1GridDesc b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
const CGM10BM0BM1GN10BN0BN1GridDesc c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
const CBlockIdToGM10GN10BlockClusterAdaptor
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor)
|
||||
{
|
||||
constexpr index_t shared_block_size =
|
||||
GridwiseContraction::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);
|
||||
|
||||
__shared__ FloatAB p_shared_block[shared_block_size];
|
||||
|
||||
GridwiseContraction::Run(p_a_grid,
|
||||
p_b_grid,
|
||||
p_c_grid,
|
||||
p_shared_block,
|
||||
a_gk0_gm0_gm10_gm11_gk1_grid_desc,
|
||||
b_gk0_gn0_gn10_gn11_gk1_grid_desc,
|
||||
c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
|
||||
c_blockid_to_gm10_gn10_block_cluster_adaptor,
|
||||
integral_constant<bool, HasMainKBlockLoop>{},
|
||||
integral_constant<bool, HasDoubleTailKBlockLoop>{});
|
||||
}

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AGK0GM0GM1GK1GridDesc,
          typename BGK0GN0GN1GK1GridDesc,
          typename CGM0GM1GN0GN1GridDesc,
          index_t GM1PerBlockGM11,
          index_t GN1PerBlockGN11,
          index_t KPerBlock,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          index_t M11N11ThreadClusterM1100,
          index_t M11N11ThreadClusterN1100,
          index_t M11N11ThreadClusterM1101,
          index_t M11N11ThreadClusterN1101,
          typename ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicContraction_k0m0m1k1_k0n0n1k1_m0m1n0n1_v1r2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // GM0 and GN0 need to be known at compile-time
    static constexpr auto GM0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I0);
    static constexpr auto GN0 = CGM0GM1GN0GN1GridDesc{}.GetLength(I2);
    static constexpr auto GK1 = AGK0GM0GM1GK1GridDesc{}.GetLength(I3);

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // LDS max alignment
        // TODO: part of this should be moved into blockwise-gemm
        // TODO: change this; it likely needs multi-dimensional alignment
        constexpr auto max_lds_align = GK1;

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
            make_dynamic_naive_tensor_descriptor_aligned_v2(
                make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
                max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
            make_dynamic_naive_tensor_descriptor_aligned_v2(
                make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
                max_lds_align);

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }
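
    // Worked example (illustrative parameters, not fixed by this file): with KPerBlock = 8,
    // GM0 = GN0 = 1, GM1PerBlockGM11 = GN1PerBlockGN11 = 128, GK1 = 4 and FloatAB = half,
    // each block descriptor spans 8 * 1 * 1 * 128 * 4 = 4096 elements (already a multiple of
    // max_lds_align = GK1). Double buffering then needs 2 * (4096 + 4096) * 2 bytes = 32 KiB
    // of LDS, which fits the 64 KiB available per workgroup on current GCN/CDNA hardware.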

    __host__ __device__ static constexpr bool
    CheckValidity(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc,
                  const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc,
                  const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
    {
        static_assert(is_known_at_compile_time<remove_cv_t<decltype(GM0)>>::value &&
                          is_known_at_compile_time<remove_cv_t<decltype(GN0)>>::value,
                      "wrong! GM0 and GN0 need to be known at compile-time");

        const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);
        const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);
        const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        return ((GM0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I0) &&
                 GM1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1) &&
                 GN0 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I2) &&
                 GN1 == c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3) &&
                 GM0 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I1) &&
                 GM1 == a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2) &&
                 GN0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I1) &&
                 GN1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2) &&
                 GK0 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0) &&
                 GK1 == b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I3)) &&
                (GM1 % GM1PerBlockGM11 == 0 && GN1 % GN1PerBlockGN11 == 0 &&
                 GK0 % KPerBlock == 0));
    }

    __host__ __device__ static constexpr index_t
    CalculateGridSize(const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
    {
        const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
        const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);

        constexpr index_t GM11 = GM1PerBlockGM11;
        constexpr index_t GN11 = GN1PerBlockGN11;

        const index_t GM10 = GM1 / GM11;
        const index_t GN10 = GN1 / GN11;

        const index_t grid_size = GM10 * GN10;

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t GK0)
    {
        const bool has_main_k_block_loop = (GK0 + KPerBlock) / (2 * KPerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t GK0)
    {
        const bool has_double_tail_k_block_loop = (GK0 / KPerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }
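
    // Worked example for the two predicates above (illustrative numbers): with KPerBlock = 8
    // and GK0 = 64 there are 8 K-block iterations. The double-buffered main loop consumes them
    // two at a time, so CalculateHasMainKBlockLoop() returns (64 + 8) / 16 = 4 > 1 (true), and
    // CalculateHasDoubleTailKBlockLoop() returns (64 / 8) % 2 == 0 (true): an even iteration
    // count leaves a 2-iteration tail, an odd count leaves a 1-iteration tail.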

    __host__ __device__ static constexpr auto
    MakeAGK0GM0GM10GM11GK1GridDescriptor(const AGK0GM0GM1GK1GridDesc& a_gk0_gm0_gm1_gk1_grid_desc)
    {
        const auto GK0 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I0);
        const auto GM1 = a_gk0_gm0_gm1_gk1_grid_desc.GetLength(I2);

        const auto GM11 = Number<GM1PerBlockGM11>{};
        const auto GM10 = GM1 / GM11;

        const auto a_gk0_gm0_gm10_gm11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
            a_gk0_gm0_gm1_gk1_grid_desc,
            make_tuple(make_pass_through_transform(GK0),
                       make_pass_through_transform(GM0),
                       make_unmerge_transform(make_tuple(GM10, GM11)),
                       make_pass_through_transform(GK1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return a_gk0_gm0_gm10_gm11_gk1_grid_desc;
    }
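
    // The unmerge above splits the single GM1 dimension into (GM10, GM11): for example
    // (illustrative numbers), GM1 = 1024 with GM1PerBlockGM11 = 128 gives GM10 = 8, and a
    // source index gm1 maps to (gm10, gm11) = (gm1 / 128, gm1 % 128), i.e.
    // gm1 = gm10 * 128 + gm11. MakeBGK0GN0GN10GN11GK1GridDescriptor below applies the same
    // split to GN1.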

    __host__ __device__ static constexpr auto
    MakeBGK0GN0GN10GN11GK1GridDescriptor(const BGK0GN0GN1GK1GridDesc& b_gk0_gn0_gn1_gk1_grid_desc)
    {
        const auto GK0 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I0);
        const auto GN1 = b_gk0_gn0_gn1_gk1_grid_desc.GetLength(I2);

        const auto GN11 = Number<GN1PerBlockGN11>{};
        const auto GN10 = GN1 / GN11;

        const auto b_gk0_gn0_gn10_gn11_gk1_grid_desc = transform_dynamic_tensor_descriptor(
            b_gk0_gn0_gn1_gk1_grid_desc,
            make_tuple(make_pass_through_transform(GK0),
                       make_pass_through_transform(GN0),
                       make_unmerge_transform(make_tuple(GN10, GN11)),
                       make_pass_through_transform(GK1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4>{}));

        return b_gk0_gn0_gn10_gn11_gk1_grid_desc;
    }

    __host__ __device__ static constexpr auto MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(
        const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
    {
        const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
        const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);

        constexpr auto GM11 = Number<GM1PerBlockGM11>{};
        constexpr auto GN11 = Number<GN1PerBlockGN11>{};

        const auto GM10 = GM1 / GM11;
        const auto GN10 = GN1 / GN11;

        constexpr auto BM = GM0 * GM11;
        constexpr auto BN = GN0 * GN11;

        constexpr auto BM1 =
            Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
        constexpr auto BN1 =
            Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};

        constexpr auto BM0 = BM / BM1;
        constexpr auto BN0 = BN / BN1;

        const auto c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc = transform_dynamic_tensor_descriptor(
            c_gm0_gm1_gn0_gn1_grid_desc,
            make_tuple(make_pass_through_transform(GM0),
                       make_unmerge_transform(make_tuple(GM10, GM11)),
                       make_pass_through_transform(GN0),
                       make_unmerge_transform(make_tuple(GN10, GN11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

        const auto c_gm10_bm_gn10_bn_grid_desc = transform_dynamic_tensor_descriptor(
            c_gm0_gm10_gm11_gn0_gn10_gn11_grid_desc,
            make_tuple(make_pass_through_transform(GM10),
                       make_merge_transform(make_tuple(GM0, GM11)),
                       make_pass_through_transform(GN10),
                       make_merge_transform(make_tuple(GN0, GN11))),
            make_tuple(Sequence<1>{}, Sequence<0, 2>{}, Sequence<4>{}, Sequence<3, 5>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));

        const auto c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc = transform_dynamic_tensor_descriptor(
            c_gm10_bm_gn10_bn_grid_desc,
            make_tuple(make_pass_through_transform(GM10),
                       make_unmerge_transform(make_tuple(BM0, BM1)),
                       make_pass_through_transform(GN10),
                       make_unmerge_transform(make_tuple(BN0, BN1))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}));

        return c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc;
    }

    __host__ __device__ static constexpr auto MakeCBlockIdToGM10GN10BlockClusterAdaptor(
        const CGM0GM1GN0GN1GridDesc& c_gm0_gm1_gn0_gn1_grid_desc)
    {
        const auto GM1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I1);
        const auto GN1 = c_gm0_gm1_gn0_gn1_grid_desc.GetLength(I3);

        constexpr auto GM11 = Number<GM1PerBlockGM11>{};
        constexpr auto GN11 = Number<GN1PerBlockGN11>{};

        const auto GM10 = GM1 / GM11;
        const auto GN10 = GN1 / GN11;

        const auto c_blockid_to_gm10_gn10_block_cluster_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(GM10, GN10))),
            make_tuple(Sequence<0, 1>{}),
            make_tuple(Sequence<0>{}));

        return c_blockid_to_gm10_gn10_block_cluster_adaptor;
    }

    using AGK0GM0GM10GM11GK1GridDesc =
        decltype(MakeAGK0GM0GM10GM11GK1GridDescriptor(AGK0GM0GM1GK1GridDesc{}));
    using BGK0GN0GN10GN11GK1GridDesc =
        decltype(MakeBGK0GN0GN10GN11GK1GridDescriptor(BGK0GN0GN1GK1GridDesc{}));
    using CGM10BM0BM1GN10BN0BN1GridDesc =
        decltype(MakeCGM10BM0BM1GN10BN0BN1GridDescriptor(CGM0GM1GN0GN1GridDesc{}));
    using CBlockIdToGM10GN10BlockClusterAdaptor =
        decltype(MakeCBlockIdToGM10GN10BlockClusterAdaptor(CGM0GM1GN0GN1GridDesc{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AGK0GM0GM10GM11GK1GridDesc& a_gk0_gm0_gm10_gm11_gk1_grid_desc,
        const BGK0GN0GN10GN11GK1GridDesc& b_gk0_gn0_gn10_gn11_gk1_grid_desc,
        const CGM10BM0BM1GN10BN0BN1GridDesc& c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
        const CBlockIdToGM10GN10BlockClusterAdaptor& c_blockid_to_gm10_gn10_block_cluster_adaptor,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_a_grid, a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_b_grid, b_gk0_gn0_gn10_gn11_gk1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_c_grid, c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc.GetElementSpaceSize());

        const auto GK0 = a_gk0_gm0_gm10_gm11_gk1_grid_desc.GetLength(I0);

        // divide block work by [GM10, GN10]
        const auto c_gm10_gn10_block_cluster_idx =
            c_blockid_to_gm10_gn10_block_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(get_block_1d_id()));

        // HACK: this forces the index data into SGPRs
        const index_t igm10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I0]);
        const index_t ign10 = __builtin_amdgcn_readfirstlane(c_gm10_gn10_block_cluster_idx[I1]);

        // LDS max alignment
        // TODO: part of this should be moved into blockwise-gemm
        // TODO: change this; it likely needs multi-dimensional alignment
        constexpr auto max_lds_align = GK1;

        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_gk0_gm0_gm10_gm11_gk1_block_desc =
            make_dynamic_naive_tensor_descriptor_aligned_v2(
                make_tuple(Number<KPerBlock>{}, GM0, I1, Number<GM1PerBlockGM11>{}, GK1),
                max_lds_align);

        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_gk0_gn0_gn10_gn11_gk1_block_desc =
            make_dynamic_naive_tensor_descriptor_aligned_v2(
                make_tuple(Number<KPerBlock>{}, GN0, I1, Number<GN1PerBlockGN11>{}, GK1),
                max_lds_align);

        // A matrix in LDS memory for blockwise GEMM
        // be careful of LDS alignment
        constexpr auto a_gk0_bm_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, GM0 * Number<GM1PerBlockGM11>{}, GK1), max_lds_align);

        // B matrix in LDS memory for blockwise GEMM
        // be careful of LDS alignment
        constexpr auto b_gk0_bn_gk1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, GN0 * Number<GN1PerBlockGN11>{}, GK1), max_lds_align);

        static_assert(a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize() ==
                              a_gk0_bm_gk1_block_desc.GetElementSpaceSize() &&
                          b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize() ==
                              b_gk0_bn_gk1_block_desc.GetElementSpaceSize(),
                      "wrong!");

        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
            BlockSize,
            InMemoryDataOperation::Set,
            Sequence<KPerBlock, GM0, 1, GM1PerBlockGM11, GK1.value>,
            ABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(a_gk0_gm0_gm10_gm11_gk1_grid_desc),
            decltype(a_gk0_gm0_gm10_gm11_gk1_block_desc),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3, 4>,
            ABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>,                         // DstVectorTensorContiguousDimOrder
            false,
            true>(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                  make_multi_index(0, 0, igm10, 0, 0),
                  a_gk0_gm0_gm10_gm11_gk1_block_desc,
                  make_multi_index(0, 0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
            BlockSize,
            InMemoryDataOperation::Set,
            Sequence<KPerBlock, GN0, 1, GN1PerBlockGN11, GK1.value>,
            BBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(b_gk0_gn0_gn10_gn11_gk1_grid_desc),
            decltype(b_gk0_gn0_gn10_gn11_gk1_block_desc),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3, 4>,
            BBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // SrcVectorTensorLengths
            BBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1, // DstVectorTensorLengths
            BBlockTransferSrcVectorTensorContiguousDimOrder, // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3, 4>,                         // DstVectorTensorContiguousDimOrder
            false,
            true>(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                  make_multi_index(0, 0, ign10, 0, 0),
                  b_gk0_gn0_gn10_gn11_gk1_block_desc,
                  make_multi_index(0, 0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, GM1PerBlockGM11] is in LDS
        //   b_mtx[KPerBlock, GN1PerBlockGN11] is in LDS
        //   c_mtx[GM1PerBlockGM11, GN1PerBlockGN11] is distributed among threads, and saved in
        //   register
        const auto blockwise_gemm =
            BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
                                                                 FloatAB,
                                                                 FloatAB,
                                                                 FloatAcc,
                                                                 decltype(a_gk0_bm_gk1_block_desc),
                                                                 decltype(b_gk0_bn_gk1_block_desc),
                                                                 M1PerThreadM111,
                                                                 N1PerThreadN111,
                                                                 KPerThread,
                                                                 M11N11ThreadClusterM1100,
                                                                 M11N11ThreadClusterN1100,
                                                                 M11N11ThreadClusterM1101,
                                                                 M11N11ThreadClusterN1101,
                                                                 M1PerThreadM111,
                                                                 N1PerThreadN111>{};

        constexpr auto c_bm0_bm1_bn0_bn1_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

        constexpr auto c_bm0_bm1_bn0_bn1_thread_desc =
            make_dynamic_naive_tensor_descriptor_packed_v2(
                sequence_to_tuple_of_number(c_bm0_bm1_bn0_bn1_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
            c_bm0_bm1_bn0_bn1_thread_desc.GetElementSpaceSize());

        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                           decltype(c_bm0_bm1_bn0_bn1_thread_desc),
                                           decltype(c_bm0_bm1_bn0_bn1_thread_tensor_lengths)>{}
            .Run(c_bm0_bm1_bn0_bn1_thread_desc,
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_a_block_double, a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_b_block_double, b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());

        auto a_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_a_block_double + a_block_aligned_space_size,
            a_gk0_gm0_gm10_gm11_gk1_block_desc.GetElementSpaceSize());
        auto b_block_odd_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_b_block_double + b_block_aligned_space_size,
            b_gk0_gn0_gn10_gn11_gk1_block_desc.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(
                a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
            b_blockwise_copy.RunRead(
                b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});

            a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);
        }
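
        // Double-buffer schedule (summary of the control flow below): LDS holds an even and an
        // odd copy of the A/B tiles. Each main-loop trip processes two K blocks: the GEMM on
        // the even buffers overlaps the global->LDS transfer into the odd buffers, then the
        // roles swap. The tail then runs the GEMM on the last one or two preloaded tiles.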

        if constexpr(HasMainKBlockLoop)
        {
            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use a do-while loop instead of a for loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                                                    a_block_slice_copy_step,
                                                    AGridMoveSliceWindowIteratorHacks{});
                b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                                                    b_block_slice_copy_step,
                                                    BGridMoveSliceWindowIteratorHacks{});

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
                b_blockwise_copy.RunRead(
                    b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_bm0_bm1_bn0_bn1_thread_desc,
                                   a_block_even_buf,
                                   b_block_even_buf,
                                   c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                                                    a_block_slice_copy_step,
                                                    AGridMoveSliceWindowIteratorHacks{});
                b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                                                    b_block_slice_copy_step,
                                                    BGridMoveSliceWindowIteratorHacks{});

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
                b_blockwise_copy.RunRead(
                    b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
                    c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_even_buf);

                k_block_data_begin += 2 * KPerBlock;
            } while(k_block_data_begin < GK0 - 2 * KPerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_gk0_gm0_gm10_gm11_gk1_grid_desc,
                                                a_block_slice_copy_step,
                                                AGridMoveSliceWindowIteratorHacks{});
            b_blockwise_copy.MoveSrcSliceWindow(b_gk0_gn0_gn10_gn11_gk1_grid_desc,
                                                b_block_slice_copy_step,
                                                BGridMoveSliceWindowIteratorHacks{});

            __syncthreads();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(
                a_gk0_gm0_gm10_gm11_gk1_grid_desc, a_global_buf, AGridIteratorHacks{});
            b_blockwise_copy.RunRead(
                b_gk0_gn0_gn10_gn11_gk1_grid_desc, b_global_buf, BGridIteratorHacks{});

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(
                c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_gk0_gm0_gm10_gm11_gk1_block_desc, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_gk0_gn0_gn10_gn11_gk1_block_desc, b_block_odd_buf);

            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_bm0_bm1_bn0_bn1_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if there is 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_bm0_bm1_bn0_bn1_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr index_t M11 =
                M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
            constexpr index_t N11 =
                N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;

            constexpr index_t M10 = GM1PerBlockGM11 / M11;
            constexpr index_t N10 = GN1PerBlockGN11 / N11;

            constexpr index_t M111 = M1PerThreadM111;
            constexpr index_t N111 = N1PerThreadN111;

            constexpr auto c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc =
                make_dynamic_naive_tensor_descriptor_packed_v2(
                    make_tuple(I1,
                               Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0]>{},
                               Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1]>{},
                               I1,
                               Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2]>{},
                               Number<c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>{}));

            const auto c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block =
                blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

            ThreadwiseDynamicTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc),
                decltype(c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc),
                Sequence<1,
                         c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I0],
                         c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I1],
                         1,
                         c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I2],
                         c_bm0_bm1_bn0_bn1_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                      make_multi_index(igm10,
                                       c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I0],
                                       c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I1],
                                       ign10,
                                       c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I2],
                                       c_bm0_bm1_bn0_bn1_thread_origin_idx_on_block[I3])}
                .Run(c_gm10_bm0_bm1_gn10_bn0_bn1_thread_desc,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_gm10_bm0_bm1_gn10_bn0_bn1_grid_desc,
                     c_grid_buf,
                     CGridIteratorHacks{});
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,669 @@
#ifndef CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP
#define CK_GRIDWISE_DYNAMIC_GEMM_V1R3_HPP

#include "common_header.hpp"
#include "dynamic_multi_index_transform_helper.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"
#include "blockwise_gemm_v2r3.hpp"
#include "blockwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_transfer_v2.hpp"
#include "threadwise_dynamic_tensor_slice_set.hpp"

namespace ck {

#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AK0M0M1K1GridDesc,
          typename BK0N0N1K1GridDesc,
          typename CM0M10M11N0N10N11GridDesc,
          typename CBlockIdToM0N0BlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_dynamic_gemm_v1r3(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc,
            const BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc,
            const CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc,
            const CBlockIdToM0N0BlockClusterAdaptor c_blockid_to_m0_n0_block_cluster_adaptor)
{
    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k0_m0_m1_k1_grid_desc,
                      b_k0_n0_n1_k1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// pass tensor descriptors by __CONSTANT__ void pointer
// __CONSTANT__ is needed to inform the compiler that the void pointers in the kernel signature
// point to the non-modifiable parameter address space, so the compiler can enable the
// corresponding optimizations
template <typename GridwiseGemm,
          typename FloatAB,
          typename FloatC,
          typename AK0M0M1K1GridDesc,
          typename BK0N0N1K1GridDesc,
          typename CM0M10M11N0N10N11GridDesc,
          typename CBlockIdToM0N0BlockClusterAdaptor,
          bool HasMainKBlockLoop,
          bool HasDoubleTailKBlockLoop>
__global__ void
#if CK_USE_LAUNCH_BOUNDS
    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
#endif
        kernel_dynamic_gemm_v1r3(
            const FloatAB* __restrict__ p_a_grid,
            const FloatAB* __restrict__ p_b_grid,
            FloatC* __restrict__ p_c_grid,
            const void __CONSTANT__* p_a_k0_m0_m1_k1_grid_desc,
            const void __CONSTANT__* p_b_k0_n0_n1_k1_grid_desc,
            const void __CONSTANT__* p_c_m0_m10_m11_n0_n10_n11_grid_desc,
            const void __CONSTANT__* p_c_blockid_to_m0_n0_block_cluster_adaptor)
{
    // first cast the __CONSTANT__ void* to a plain void*,
    // then cast the void* to the concrete Desc*;
    // the copy constructor of the tensor descriptor doesn't accept an address_space(4) pointer
    const auto a_k0_m0_m1_k1_grid_desc =
        *reinterpret_cast<const AK0M0M1K1GridDesc*>((const void*)p_a_k0_m0_m1_k1_grid_desc);
    const auto b_k0_n0_n1_k1_grid_desc =
        *reinterpret_cast<const BK0N0N1K1GridDesc*>((const void*)p_b_k0_n0_n1_k1_grid_desc);
    const auto c_m0_m10_m11_n0_n10_n11_grid_desc =
        *reinterpret_cast<const CM0M10M11N0N10N11GridDesc*>(
            (const void*)p_c_m0_m10_m11_n0_n10_n11_grid_desc);
    const auto c_blockid_to_m0_n0_block_cluster_adaptor =
        *reinterpret_cast<const CBlockIdToM0N0BlockClusterAdaptor*>(
            (const void*)p_c_blockid_to_m0_n0_block_cluster_adaptor);

    constexpr index_t shared_block_size =
        GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB);

    __shared__ FloatAB p_shared_block[shared_block_size];

    GridwiseGemm::Run(p_a_grid,
                      p_b_grid,
                      p_c_grid,
                      p_shared_block,
                      a_k0_m0_m1_k1_grid_desc,
                      b_k0_n0_n1_k1_grid_desc,
                      c_m0_m10_m11_n0_n10_n11_grid_desc,
                      c_blockid_to_m0_n0_block_cluster_adaptor,
                      integral_constant<bool, HasMainKBlockLoop>{},
                      integral_constant<bool, HasDoubleTailKBlockLoop>{});
}
#endif

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename FloatC,
          InMemoryDataOperation CGlobalMemoryDataOperation,
          typename AK0MK1GridDesc,
          typename BK0NK1GridDesc,
          typename CMNGridDesc,
          index_t MPerBlockM1,
          index_t NPerBlockN1,
          index_t KPerBlock,
          index_t M1PerThreadM111,
          index_t N1PerThreadN111,
          index_t KPerThread,
          index_t M11N11ThreadClusterM1100,
          index_t M11N11ThreadClusterN1100,
          index_t M11N11ThreadClusterM1101,
          index_t M11N11ThreadClusterN1101,
          typename ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          typename ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
          typename ABlockTransferSrcVectorTensorContiguousDimOrder,
          typename ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
          typename BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          typename BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
          typename BBlockTransferSrcVectorTensorContiguousDimOrder,
          typename BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
          typename CThreadTransferSrcDstAccessOrder,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          typename AGridIteratorHacks,
          typename BGridIteratorHacks,
          typename CGridIteratorHacks,
          typename AGridMoveSliceWindowIteratorHacks,
          typename BGridMoveSliceWindowIteratorHacks>
struct GridwiseDynamicGemm_km_kn_mn_v1r3
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    // K1 should be Number<...>
    static constexpr auto K1 = AK0MK1GridDesc{}.GetLength(I2);

    __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
    {
        // TODO: change this; it likely needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        constexpr auto a_k_m_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        constexpr auto b_k_n_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);

        // TODO: check alignment
        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size =
            math::integer_least_multiple(a_k_m_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size =
            math::integer_least_multiple(b_k_n_block_desc.GetElementSpaceSize(), max_lds_align);

        return 2 * (a_block_aligned_space_size + b_block_aligned_space_size) * sizeof(FloatAB);
    }

    __host__ __device__ static constexpr bool
    CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc,
                  const BK0NK1GridDesc& b_k0_n_k1_grid_desc,
                  const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);
        const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);
        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
        const auto K1 = a_k0_m_k1_grid_desc.GetLength(I2);

        // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc)

        return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) &&
                K0 == b_k0_n_k1_grid_desc.GetLength(I0) &&
                K1 == b_k0_n_k1_grid_desc.GetLength(I2)) &&
               (M % MPerBlockM1 == 0 && N % NPerBlockN1 == 0 && K0 % KPerBlock == 0);
    }

    __host__ __device__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
    {
        const index_t grid_size = (M / MPerBlockM1) * (N / NPerBlockN1);

        return grid_size;
    }

    __host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K0)
    {
        const bool has_main_k_block_loop = (K0 + KPerBlock) / (2 * KPerBlock) > 1;

        return has_main_k_block_loop;
    }

    __host__ __device__ static constexpr bool CalculateHasDoubleTailKBlockLoop(index_t K0)
    {
        const bool has_double_tail_k_block_loop = (K0 / KPerBlock) % 2 == 0;

        return has_double_tail_k_block_loop;
    }

    __host__ __device__ static constexpr auto
    MakeAK0M0M1K1GridDescriptor(const AK0MK1GridDesc& a_k0_m_k1_grid_desc)
    {
        const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0);
        const auto M  = a_k0_m_k1_grid_desc.GetLength(I1);

        const auto M1 = Number<MPerBlockM1>{};
        const auto M0 = M / M1;

        const auto a_k0_m0_m1_k1_grid_desc = transform_dynamic_tensor_descriptor(
            a_k0_m_k1_grid_desc,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(M0, M1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return a_k0_m0_m1_k1_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeBK0N0N1K1GridDescriptor(const BK0NK1GridDesc& b_k0_n_k1_grid_desc)
    {
        const auto K0 = b_k0_n_k1_grid_desc.GetLength(I0);
        const auto N  = b_k0_n_k1_grid_desc.GetLength(I1);

        const auto N1 = Number<NPerBlockN1>{};
        const auto N0 = N / N1;

        const auto b_k0_n0_n1_k1_grid_desc = transform_dynamic_tensor_descriptor(
            b_k0_n_k1_grid_desc,
            make_tuple(make_pass_through_transform(K0),
                       make_unmerge_transform(make_tuple(N0, N1)),
                       make_pass_through_transform(K1)),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{}));

        return b_k0_n0_n1_k1_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCM0M10M11N0N10N11GridDescriptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlockM1>{};
        constexpr auto N1 = Number<NPerBlockN1>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        constexpr auto M11 =
            Number<M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101 * M1PerThreadM111>{};
        constexpr auto N11 =
            Number<M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101 * N1PerThreadN111>{};

        constexpr auto M10 = M1 / M11;
        constexpr auto N10 = N1 / N11;

        const auto c_m0_m10_m11_n0_n10_n11_grid_desc = transform_dynamic_tensor_descriptor(
            c_m_n_grid_desc,
            make_tuple(make_unmerge_transform(make_tuple(M0, M10, M11)),
                       make_unmerge_transform(make_tuple(N0, N10, N11))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4, 5>{}));

        return c_m0_m10_m11_n0_n10_n11_grid_desc;
    }

    __host__ __device__ static constexpr auto
    MakeCBlockIdToM0N0BlockClusterAdaptor(const CMNGridDesc& c_m_n_grid_desc)
    {
        const auto M = c_m_n_grid_desc.GetLength(I0);
        const auto N = c_m_n_grid_desc.GetLength(I1);

        constexpr auto M1 = Number<MPerBlockM1>{};
        constexpr auto N1 = Number<NPerBlockN1>{};

        const auto M0 = M / M1;
        const auto N0 = N / N1;

        const auto c_blockid_to_m0_n0_block_cluster_adaptor =
            make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))),
                                             make_tuple(Sequence<0, 1>{}),
                                             make_tuple(Sequence<0>{}));

        return c_blockid_to_m0_n0_block_cluster_adaptor;
    }
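
    // The merge transform above linearizes the (M0, N0) block grid with M0 as the slower
    // dimension, so (as an illustration) a block_id maps back to
    // (m0, n0) = (block_id / N0, block_id % N0); CalculateBottomIndex() in Run() recovers
    // exactly this pair for each workgroup.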

    using AK0M0M1K1GridDesc = decltype(MakeAK0M0M1K1GridDescriptor(AK0MK1GridDesc{}));
    using BK0N0N1K1GridDesc = decltype(MakeBK0N0N1K1GridDescriptor(BK0NK1GridDesc{}));
    using CM0M10M11N0N10N11GridDesc = decltype(MakeCM0M10M11N0N10N11GridDescriptor(CMNGridDesc{}));
    using CBlockIdToM0N0BlockClusterAdaptor =
        decltype(MakeCBlockIdToM0N0BlockClusterAdaptor(CMNGridDesc{}));

    template <bool HasMainKBlockLoop, bool HasDoubleTailKBlockLoop>
    __device__ static void
    Run(const FloatAB* __restrict__ p_a_grid,
        const FloatAB* __restrict__ p_b_grid,
        FloatC* __restrict__ p_c_grid,
        FloatAB* __restrict__ p_shared_block,
        const AK0M0M1K1GridDesc& a_k0_m0_m1_k1_grid_desc,
        const BK0N0N1K1GridDesc& b_k0_n0_n1_k1_grid_desc,
        const CM0M10M11N0N10N11GridDesc& c_m0_m10_m11_n0_n10_n11_grid_desc,
        const CBlockIdToM0N0BlockClusterAdaptor& c_blockid_to_m0_n0_block_cluster_adaptor,
        integral_constant<bool, HasMainKBlockLoop>,
        integral_constant<bool, HasDoubleTailKBlockLoop>)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_a_grid, a_k0_m0_m1_k1_grid_desc.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_b_grid, b_k0_n0_n1_k1_grid_desc.GetElementSpaceSize());
        auto c_grid_buf = make_dynamic_buffer<AddressSpace::Global>(
            p_c_grid, c_m0_m10_m11_n0_n10_n11_grid_desc.GetElementSpaceSize());

        // divide block work by [M, N]
        const auto c_m0_n0_block_cluster_idx =
            c_blockid_to_m0_n0_block_cluster_adaptor.CalculateBottomIndex(
                make_multi_index(get_block_1d_id()));

        // HACK: this forces the index data into SGPRs
        const index_t im0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I0]);
        const index_t in0 = __builtin_amdgcn_readfirstlane(c_m0_n0_block_cluster_idx[I1]);

        // TODO: change this; it likely needs multi-dimensional alignment
        constexpr auto max_lds_align = K1;

        // TODO: check alignment
        // A matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto a_k0_m0_m1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, I1, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, dst of blockwise copy
        // be careful of LDS alignment
        constexpr auto b_k0_n0_n1_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, I1, Number<NPerBlockN1>{}, K1), max_lds_align);

        // TODO: check alignment
        // A matrix in LDS memory, for blockwise GEMM
        constexpr auto a_k0_m_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<MPerBlockM1>{}, K1), max_lds_align);

        // TODO: check alignment
        // B matrix in LDS memory, for blockwise GEMM
        constexpr auto b_k0_n_k1_block_desc = make_dynamic_naive_tensor_descriptor_aligned_v2(
            make_tuple(Number<KPerBlock>{}, Number<NPerBlockN1>{}, K1), max_lds_align);

        static_assert(a_k0_m0_m1_k1_block_desc.GetElementSpaceSize() ==
                              a_k0_m_k1_block_desc.GetElementSpaceSize() &&
                          b_k0_n0_n1_k1_block_desc.GetElementSpaceSize() ==
                              b_k0_n_k1_block_desc.GetElementSpaceSize(),
                      "wrong!");

        // A matrix blockwise copy
        auto a_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
            BlockSize,
            InMemoryDataOperation::Set,
            Sequence<KPerBlock, 1, MPerBlockM1, K1.value>,
            ABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterLengths_K0_M0_M1_K1,
            ABlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(a_k0_m0_m1_k1_grid_desc),
            decltype(a_k0_m0_m1_k1_block_desc),
            ABlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3>,
            ABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1, // SrcVectorTensorLengths
            ABlockTransferDstVectorTensorLengths_K0_M0_M1_K1, // DstVectorTensorLengths
            ABlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3>,                             // DstVectorTensorContiguousDimOrder
            false,
            true>(a_k0_m0_m1_k1_grid_desc,
                  make_multi_index(0, im0, 0, 0),
                  a_k0_m0_m1_k1_block_desc,
                  make_multi_index(0, 0, 0, 0));

        // B matrix blockwise copy
        auto b_blockwise_copy = BlockwiseDynamicTensorSliceTransfer_v4r1<
            BlockSize,
            InMemoryDataOperation::Set,
            Sequence<KPerBlock, 1, NPerBlockN1, K1.value>,
            BBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterLengths_K0_N0_N1_K1,
            BBlockTransferThreadClusterArrangeOrder,
            FloatAB,
            FloatAB,
            decltype(b_k0_n0_n1_k1_grid_desc),
            decltype(b_k0_n0_n1_k1_block_desc),
            BBlockTransferSrcAccessOrder,
            Sequence<0, 1, 2, 3>,
            BBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1, // SrcVectorTensorLengths
            BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, // DstVectorTensorLengths
            BBlockTransferSrcVectorTensorContiguousDimOrder,  // SrcVectorTensorContiguousDimOrder
            Sequence<0, 1, 2, 3>,                             // DstVectorTensorContiguousDimOrder
            false,
            true>(b_k0_n0_n1_k1_grid_desc,
                  make_multi_index(0, in0, 0, 0),
                  b_k0_n0_n1_k1_block_desc,
                  make_multi_index(0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //   a_mtx[KPerBlock, MPerBlockM1] is in LDS
        //   b_mtx[KPerBlock, NPerBlockN1] is in LDS
        //   c_mtx[MPerBlockM1, NPerBlockN1] is distributed among threads, and saved in
        //   register
        const auto blockwise_gemm =
            BlockwiseGemm_k0mk1_k0nk1_m0m1n0n1_v2r3_pipeline_2x2<BlockSize,
                                                                 FloatAB,
                                                                 FloatAB,
                                                                 FloatAcc,
                                                                 decltype(a_k0_m_k1_block_desc),
                                                                 decltype(b_k0_n_k1_block_desc),
                                                                 M1PerThreadM111,
                                                                 N1PerThreadN111,
                                                                 KPerThread,
                                                                 M11N11ThreadClusterM1100,
                                                                 M11N11ThreadClusterN1100,
                                                                 M11N11ThreadClusterM1101,
                                                                 M11N11ThreadClusterN1101,
                                                                 M1PerThreadM111,
                                                                 N1PerThreadN111>{};

        constexpr auto c_m10_m11_n10_n11_thread_tensor_lengths =
            decltype(blockwise_gemm)::GetCM0M1N0N1ThreadTensorLengths();

        constexpr auto c_m10_m11_n10_n11_thread_desc =
            make_dynamic_naive_tensor_descriptor_packed_v2(
                sequence_to_tuple_of_number(c_m10_m11_n10_n11_thread_tensor_lengths));

        // LDS allocation for A and B: be careful of alignment
        constexpr auto a_block_aligned_space_size = math::integer_least_multiple(
            a_k0_m0_m1_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        constexpr auto b_block_aligned_space_size = math::integer_least_multiple(
            b_k0_n0_n1_k1_block_desc.GetElementSpaceSize(), max_lds_align);

        FloatAB* p_a_block_double = p_shared_block;
        FloatAB* p_b_block_double = p_shared_block + 2 * a_block_aligned_space_size;

        // register allocation for output
        auto c_thread_buf = make_static_buffer<AddressSpace::Vgpr, FloatAcc>(
            c_m10_m11_n10_n11_thread_desc.GetElementSpaceSize());

        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
                                           decltype(c_m10_m11_n10_n11_thread_desc),
                                           decltype(c_m10_m11_n10_n11_thread_tensor_lengths)>{}
            .Run(c_m10_m11_n10_n11_thread_desc,
                 make_tuple(I0, I0, I0, I0),
                 c_thread_buf,
                 FloatAcc{0});

        constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);
        constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0, 0);

        auto a_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_a_block_double, a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
        auto b_block_even_buf = make_dynamic_buffer<AddressSpace::Lds>(
            p_b_block_double, b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());

        auto a_block_odd_buf =
            make_dynamic_buffer<AddressSpace::Lds>(p_a_block_double + a_block_aligned_space_size,
                                                   a_k0_m0_m1_k1_block_desc.GetElementSpaceSize());
        auto b_block_odd_buf =
            make_dynamic_buffer<AddressSpace::Lds>(p_b_block_double + b_block_aligned_space_size,
                                                   b_k0_n0_n1_k1_block_desc.GetElementSpaceSize());

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});

            a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
            b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);
        }

        if constexpr(HasMainKBlockLoop)
        {
            const auto K0 = a_k0_m0_m1_k1_grid_desc.GetLength(I0);

            index_t k_block_data_begin = 0;

            // LDS double buffer: main body
            // use a do-while loop instead of a for loop to simplify control flow
            do
            {
                // even iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
                                                    a_block_slice_copy_step,
                                                    AGridMoveSliceWindowIteratorHacks{});
                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
                                                    b_block_slice_copy_step,
                                                    BGridMoveSliceWindowIteratorHacks{});

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
                b_blockwise_copy.RunRead(
                    b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(c_m10_m11_n10_n11_thread_desc,
                                   a_block_even_buf,
                                   b_block_even_buf,
                                   c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
                b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);

                // odd iteration
                a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
                                                    a_block_slice_copy_step,
                                                    AGridMoveSliceWindowIteratorHacks{});
                b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
                                                    b_block_slice_copy_step,
                                                    BGridMoveSliceWindowIteratorHacks{});

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunRead(
                    a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
                b_blockwise_copy.RunRead(
                    b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});

                // LDS double buffer: GEMM on current data
                blockwise_gemm.Run(
                    c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_even_buf);
                b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_even_buf);

                k_block_data_begin += 2 * KPerBlock;
            } while(k_block_data_begin < K0 - 2 * KPerBlock);
        }

        // LDS double buffer: tail
        if constexpr(HasDoubleTailKBlockLoop) // if there are 2 iterations left
        {
            a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc,
                                                a_block_slice_copy_step,
                                                AGridMoveSliceWindowIteratorHacks{});
            b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc,
                                                b_block_slice_copy_step,
                                                BGridMoveSliceWindowIteratorHacks{});

            __syncthreads();

            // LDS double buffer: load last data from device mem
            a_blockwise_copy.RunRead(a_k0_m0_m1_k1_grid_desc, a_global_buf, AGridIteratorHacks{});
            b_blockwise_copy.RunRead(b_k0_n0_n1_k1_grid_desc, b_global_buf, BGridIteratorHacks{});

            // LDS double buffer: GEMM on 2nd-last data
            blockwise_gemm.Run(
                c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);

            // LDS double buffer: store last data to LDS
            a_blockwise_copy.RunWrite(a_k0_m0_m1_k1_block_desc, a_block_odd_buf);
            b_blockwise_copy.RunWrite(b_k0_n0_n1_k1_block_desc, b_block_odd_buf);

            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_m10_m11_n10_n11_thread_desc, a_block_odd_buf, b_block_odd_buf, c_thread_buf);
        }
        else // if there is 1 iteration left
        {
            __syncthreads();

            // LDS double buffer: GEMM on last data
            blockwise_gemm.Run(
                c_m10_m11_n10_n11_thread_desc, a_block_even_buf, b_block_even_buf, c_thread_buf);
        }

        // output: register to global memory
        {
            constexpr index_t M11 =
                M1PerThreadM111 * M11N11ThreadClusterM1100 * M11N11ThreadClusterM1101;
            constexpr index_t N11 =
                N1PerThreadN111 * M11N11ThreadClusterN1100 * M11N11ThreadClusterN1101;

            constexpr index_t M10 = MPerBlockM1 / M11;
            constexpr index_t N10 = NPerBlockN1 / N11;

            constexpr index_t M111 = M1PerThreadM111;
            constexpr index_t N111 = N1PerThreadN111;

            constexpr auto c_m0_m10_m11_n0_n10_n11_thread_desc =
                make_dynamic_naive_tensor_descriptor_packed_v2(
                    make_tuple(I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I0]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I1]>{},
                               I1,
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I2]>{},
                               Number<c_m10_m11_n10_n11_thread_tensor_lengths[I3]>{}));

            const auto c_m10_m11_n10_n11_thread_origin_idx_on_block =
                blockwise_gemm.CalculateCM0M1N0N1ThreadOriginOnBlock(get_thread_local_1d_id());

            ThreadwiseDynamicTensorSliceTransfer_v1r3<
                FloatAcc,
                FloatC,
                decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
                decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
                Sequence<1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I0],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I1],
                         1,
                         c_m10_m11_n10_n11_thread_tensor_lengths[I2],
                         c_m10_m11_n10_n11_thread_tensor_lengths[I3]>,
                CThreadTransferSrcDstAccessOrder,
                CThreadTransferSrcDstVectorDim,
                CThreadTransferDstScalarPerVector,
                CGlobalMemoryDataOperation,
                1,
                true>{c_m0_m10_m11_n0_n10_n11_grid_desc,
                      make_multi_index(im0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I0],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
                                       in0,
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
                                       c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])}
                .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
                     make_tuple(I0, I0, I0, I0, I0, I0),
                     c_thread_buf,
                     c_m0_m10_m11_n0_n10_n11_grid_desc,
                     c_grid_buf,
                     CGridIteratorHacks{});
        }
    }
};

} // namespace ck
#endif
@@ -0,0 +1,799 @@
#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SLICE_TRANSFER_V2_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// Assume:
//   1. src_desc and dst_desc are not known at compile-time
//   2. SrcBuffer and DstBuffer are DynamicBuffer
//   3. src_slice_origin and dst_slice_origin are not known at compile-time
//   4. use thread buffer
template <typename SliceLengths,
          InMemoryDataOperation DstInMemOp,
          typename SrcData,
          typename DstData,
          typename SrcDesc,
          typename DstDesc,
          typename SrcDimAccessOrder,
          typename DstDimAccessOrder,
          typename SrcVectorTensorLengths,
          typename DstVectorTensorLengths,
          typename SrcVectorTensorContiguousDimOrder,
          typename DstVectorTensorContiguousDimOrder,
          bool SrcResetCoordinateAfterRun, // control whether to move back the src coordinate
                                           // after each RunRead(); will be fused with
                                           // MoveSrcSliceWindow to save address computation
          bool DstResetCoordinateAfterRun> // control whether to move back the dst coordinate
                                           // after each RunWrite(); will be fused with
                                           // MoveDstSliceWindow to save address computation
struct ThreadwiseDynamicTensorSliceTransfer_v3r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t nDim = SliceLengths::Size();
    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));
    using DstCoord = decltype(make_dynamic_tensor_coordinate(DstDesc{}, Index{}));

    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));
    using DstCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(DstDesc{}, Index{}));

    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v3r1(const SrcDesc& src_desc,
                                                                   const Index& src_slice_origin,
                                                                   const DstDesc& dst_desc,
                                                                   const Index& dst_slice_origin)
        : src_coord_(make_dynamic_tensor_coordinate(src_desc, src_slice_origin)),
          dst_coord_(make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin))
    {
        // TODO: fix this
        static_assert(is_same<SrcData, DstData>::value,
                      "wrong! the current implementation assumes SrcData and DstData are the "
                      "same type");

        static_for<0, nDim, 1>{}([](auto i) {
            static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0 &&
                              SliceLengths::At(i) % DstVectorTensorLengths::At(i) == 0,
                          "wrong!");
        });
    }
|
||||
|
||||
__device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx)
|
||||
{
|
||||
src_coord_ = make_dynamic_tensor_coordinate(src_desc, src_slice_origin_idx);
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(const DstDesc& dst_desc, const Index& dst_slice_origin_idx)
|
||||
{
|
||||
dst_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
|
||||
}
|
||||
|
||||
template <typename SrcBuffer, typename SrcIteratorHacks>
|
||||
__device__ void RunRead(const SrcDesc& src_desc,
|
||||
const SrcBuffer& src_buf,
|
||||
const SrcIteratorHacks& src_iterator_hacks)
|
||||
{
|
||||
static_assert(SrcBuffer::GetAddressSpace() == AddressSpace::Global or
|
||||
SrcBuffer::GetAddressSpace() == AddressSpace::Lds,
|
||||
"wrong!");
|
||||
|
||||
static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
|
||||
remove_cv_t<remove_reference_t<SrcData>>>::value,
|
||||
"wrong! SrcBuffer and SrcData data type are inconsistent");
|
||||
|
||||
// tensor descriptor for src_vector
|
||||
constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};
|
||||
|
||||
constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
|
||||
container_reverse_exclusive_scan(
|
||||
container_reorder_given_new2old(src_vector_tensor_lengths,
|
||||
SrcVectorTensorContiguousDimOrder{}),
|
||||
math::multiplies_v2{},
|
||||
I1),
|
||||
SrcVectorTensorContiguousDimOrder{});
|
||||
|
||||
constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
|
||||
sequence_to_tuple_of_number(src_vector_tensor_lengths),
|
||||
sequence_to_tuple_of_number(src_vector_tensor_strides));
|
||||
|
||||
// access order and lengths
|
||||
constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;
|
||||
|
||||
constexpr auto src_dim_access_order = SrcDimAccessOrder{};
|
||||
|
||||
constexpr auto ordered_src_access_lengths =
|
||||
container_reorder_given_new2old(src_access_lengths, src_dim_access_order);
|
||||
|
||||
// make forward iterators
|
||||
const auto src_forward_iterators = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index forward_step;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
forward_step(j) = (i.value == j.value) ? src_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_dynamic_tensor_coordinate_iterator(
|
||||
src_desc, forward_step, src_iterator_hacks[I0][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// make backward iterators
|
||||
const auto src_backward_iterators = generate_tuple(
|
||||
[&](auto i) {
|
||||
Index backward_step;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto j) {
|
||||
backward_step(j) = (i.value == j.value) ? -src_vector_tensor_lengths[i] : 0;
|
||||
});
|
||||
|
||||
return make_dynamic_tensor_coordinate_iterator(
|
||||
src_desc, backward_step, src_iterator_hacks[I1][i]);
|
||||
},
|
||||
Number<nDim>{});
|
||||
|
||||
// loop over tensor and copy
|
||||
static_ford<decltype(ordered_src_access_lengths)>{}([&](auto ordered_src_access_idx) {
|
||||
// judge move forward or move backward
|
||||
constexpr auto forward_sweep = [&]() {
|
||||
StaticallyIndexedArray<bool, nDim> forward_sweep;
|
||||
|
||||
forward_sweep(I0) = true;
|
||||
|
||||
static_for<1, nDim, 1>{}([&](auto i) {
|
||||
index_t tmp = ordered_src_access_idx[I0];
|
||||
|
||||
static_for<0, i, 1>{}([&](auto j) {
|
||||
tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_idx[j];
|
||||
});
|
||||
|
||||
forward_sweep(i) = tmp % 2 == 0;
|
||||
});
|
||||
|
||||
return forward_sweep;
|
||||
}();
|
||||
|
||||
// calculate src data index
|
||||
constexpr auto src_data_idx = [&]() {
|
||||
Index ordered_idx;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
ordered_idx(i) = forward_sweep[i] ? ordered_src_access_idx[i]
|
||||
: ordered_src_access_lengths[i] - 1 -
|
||||
ordered_src_access_idx[i];
|
||||
});
|
||||
|
||||
auto src_data_idx =
|
||||
container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
|
||||
src_vector_tensor_lengths;
|
||||
|
||||
return src_data_idx;
|
||||
}();
|
||||
|
||||
vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;
|
||||
|
||||
using src_vector_t = typename decltype(src_vector)::type;
|
||||
|
||||
const bool is_src_valid =
|
||||
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
|
||||
|
||||
// copy data from src_buf to src_vector
|
||||
src_vector.template AsType<src_vector_t>()(I0) =
|
||||
src_buf.template Get<src_vector_t>(src_coord_.GetOffset(), is_src_valid);
|
||||
|
||||
// copy data from src_vector to buffer_
|
||||
static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
|
||||
constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);
|
||||
|
||||
constexpr index_t src_vector_offset =
|
||||
src_vector_desc.CalculateOffset(src_vector_idx);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc_.CalculateOffset(src_data_idx + src_vector_idx);
|
||||
|
||||
buffer_(Number<buffer_offset>{}) =
|
||||
src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}];
|
||||
});
|
||||
|
||||
constexpr auto move_on_dim = [&]() constexpr
|
||||
{
|
||||
StaticallyIndexedArray<bool, nDim> move_on_dim;
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
move_on_dim(i) = ordered_src_access_idx[i] < ordered_src_access_lengths[i] - 1;
|
||||
|
||||
static_for<i + 1, nDim, 1>{}([&](auto j) {
|
||||
move_on_dim(i) &=
|
||||
ordered_src_access_idx[j] == ordered_src_access_lengths[j] - 1;
|
||||
});
|
||||
});
|
||||
|
||||
return move_on_dim;
|
||||
}
|
||||
();
|
||||
|
||||
// move
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
if constexpr(move_on_dim[i])
|
||||
{
|
||||
if constexpr(forward_sweep[i])
|
||||
{
|
||||
move_dynamic_tensor_coordinate(
|
||||
src_desc, src_coord_, src_forward_iterators[src_dim_access_order[i]]);
|
||||
}
|
||||
else
|
||||
{
|
||||
move_dynamic_tensor_coordinate(
|
||||
src_desc, src_coord_, src_backward_iterators[src_dim_access_order[i]]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// move src coordinate back to slice origin (or not)
|
||||
if constexpr(SrcResetCoordinateAfterRun)
|
||||
{
|
||||
const auto src_reset_iterator =
|
||||
make_dynamic_tensor_coordinate_iterator(src_desc, GetSrcCoordinateResetStep());
|
||||
|
||||
move_dynamic_tensor_coordinate(src_desc, src_coord_, src_reset_iterator);
|
||||
}
|
||||
}
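
The parity trick above implements a snake (boustrophedon) traversal: a dimension is walked forward when the linearized index of all slower dimensions is even, backward when it is odd, so consecutive accesses always differ by one step and the coordinate never jumps back across the slice. A standalone host-side sketch of the same rule, for an assumed 2-D access grid (not tied to this class):

// Illustration only: visit a LenY x LenX grid in snake order.
#include <cstdio>

int main()
{
    constexpr int LenY = 3, LenX = 4;

    for(int y = 0; y < LenY; ++y)
    {
        const bool forward = (y % 2 == 0); // forward_sweep for the fast dimension

        for(int i = 0; i < LenX; ++i)
        {
            const int x = forward ? i : LenX - 1 - i;
            std::printf("(%d, %d)\n", y, x); // neighboring visits differ by one step
        }
    }
}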

    template <typename DstBuffer, typename DstIteratorHacks>
    __device__ void RunWrite(const DstDesc& dst_desc,
                             DstBuffer& dst_buf,
                             const DstIteratorHacks& dst_iterator_hacks)
    {
        static_assert(DstBuffer::GetAddressSpace() == AddressSpace::Global or
                          DstBuffer::GetAddressSpace() == AddressSpace::Lds,
                      "wrong!");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
                              remove_cv_t<remove_reference_t<DstData>>>::value,
                      "wrong! DstBuffer and DstData data type are inconsistent");

        // tensor descriptor for dst_vector
        constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};

        constexpr auto dst_vector_tensor_strides = container_reorder_given_old2new(
            container_reverse_exclusive_scan(
                container_reorder_given_new2old(dst_vector_tensor_lengths,
                                                DstVectorTensorContiguousDimOrder{}),
                math::multiplies_v2{},
                I1),
            DstVectorTensorContiguousDimOrder{});

        constexpr auto dst_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
            sequence_to_tuple_of_number(dst_vector_tensor_lengths),
            sequence_to_tuple_of_number(dst_vector_tensor_strides));

        // dst access order and lengths
        constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // make forward iterators
        const auto dst_forward_iterators = generate_tuple(
            [&](auto i) {
                Index forward_step;

                static_for<0, nDim, 1>{}([&](auto j) {
                    forward_step(j) = (i.value == j.value) ? dst_vector_tensor_lengths[i] : 0;
                });

                const auto forward_iterator = make_dynamic_tensor_coordinate_iterator(
                    dst_desc, forward_step, dst_iterator_hacks[I0][i]);

                return forward_iterator;
            },
            Number<nDim>{});

        // make backward iterators
        const auto dst_backward_iterators = generate_tuple(
            [&](auto i) {
                Index backward_step;

                static_for<0, nDim, 1>{}([&](auto j) {
                    backward_step(j) = (i.value == j.value) ? -dst_vector_tensor_lengths[i] : 0;
                });

                const auto backward_iterator = make_dynamic_tensor_coordinate_iterator(
                    dst_desc, backward_step, dst_iterator_hacks[I1][i]);

                return backward_iterator;
            },
            Number<nDim>{});

        // loop over tensor and copy
        static_ford<decltype(ordered_dst_access_lengths)>{}([&](auto ordered_dst_access_idx) {
            // decide whether to move forward or backward along each dimension (snake traversal)
            constexpr auto forward_sweep = [&]() {
                StaticallyIndexedArray<bool, nDim> forward_sweep;

                forward_sweep(I0) = true;

                static_for<1, nDim, 1>{}([&](auto i) {
                    index_t tmp = ordered_dst_access_idx[I0];

                    static_for<0, i, 1>{}([&](auto j) {
                        tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_idx[j];
                    });

                    forward_sweep(i) = tmp % 2 == 0;
                });

                return forward_sweep;
            }();

            // calculate dst data index
            constexpr auto dst_data_idx = [&]() {
                Index ordered_idx;

                static_for<0, nDim, 1>{}([&](auto i) {
                    ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_idx[i]
                                                      : ordered_dst_access_lengths[i] - 1 -
                                                            ordered_dst_access_idx[i];
                });

                auto dst_data_idx =
                    container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                    dst_vector_tensor_lengths;

                return dst_data_idx;
            }();

            vector_type_maker_t<DstData, dst_vector_desc.GetElementSpaceSize()> dst_vector;

            // copy data from buffer_ to dst_vector (also cast from SrcData to DstData)
            static_ford<DstVectorTensorLengths>{}([&](auto dst_vector_idx_) {
                constexpr auto dst_vector_idx = to_multi_index(dst_vector_idx_);

                constexpr index_t buffer_offset =
                    buffer_desc_.CalculateOffset(dst_data_idx + dst_vector_idx);

                constexpr index_t dst_vector_offset =
                    dst_vector_desc.CalculateOffset(dst_vector_idx);

                dst_vector.template AsType<DstData>()(Number<dst_vector_offset>{}) =
                    type_convert<DstData>{}(buffer_[Number<buffer_offset>{}]);
            });

            using dst_vector_t = typename decltype(dst_vector)::type;

            // copy data from dst_vector to dst_buf
            const bool is_dst_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_);

            dst_buf.template Set<dst_vector_t>(
                dst_coord_.GetOffset(),
                is_dst_valid,
                dst_vector.template AsType<dst_vector_t>()[Number<0>{}]);

            constexpr auto move_on_dim = [&]() constexpr
            {
                StaticallyIndexedArray<bool, nDim> move_on_dim;

                static_for<0, nDim, 1>{}([&](auto i) {
                    move_on_dim(i) = ordered_dst_access_idx[i] < ordered_dst_access_lengths[i] - 1;

                    static_for<i + 1, nDim, 1>{}([&](auto j) {
                        move_on_dim(i) &=
                            ordered_dst_access_idx[j] == ordered_dst_access_lengths[j] - 1;
                    });
                });

                return move_on_dim;
            }
            ();

            // move
            static_for<0, nDim, 1>{}([&](auto i) {
                if constexpr(move_on_dim[i])
                {
                    if constexpr(forward_sweep[i])
                    {
                        move_dynamic_tensor_coordinate(
                            dst_desc, dst_coord_, dst_forward_iterators[dst_dim_access_order[i]]);
                    }
                    else
                    {
                        move_dynamic_tensor_coordinate(
                            dst_desc, dst_coord_, dst_backward_iterators[dst_dim_access_order[i]]);
                    }
                }
            });
        });

        // move dst coordinate back to slice origin (or not)
        if constexpr(DstResetCoordinateAfterRun)
        {
            const auto dst_reset_iterator =
                make_dynamic_tensor_coordinate_iterator(dst_desc, GetDstCoordinateResetStep());

            move_dynamic_tensor_coordinate(dst_desc, dst_coord_, dst_reset_iterator);
        }
    }
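
Both Run methods move data in two hops: one wide vector access against the descriptor-addressed buffer and a fully unrolled scalar scatter/gather against the VGPR-resident buffer_. A standalone scalar model of the read hop (illustrative only; the vector type, offsets, and float element type are assumptions):

#include <cstring>

struct float4_like { float lane[4]; }; // stands in for a 4-wide vector type

// One RunRead step: a single vector load, then an elementwise scatter
// into the register buffer at descriptor-computed offsets.
void read_step(const float* src, int src_offset, float* reg_buf, int buf_offset)
{
    float4_like v;
    std::memcpy(&v, src + src_offset, sizeof(v)); // one vector load
    for(int i = 0; i < 4; ++i)
        reg_buf[buf_offset + i] = v.lane[i];      // unrolled scatter
}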

    template <typename SrcBuffer>
    __device__ void RunRead(const SrcDesc& src_desc, const SrcBuffer& src_buf)
    {
        constexpr index_t ntransform_src = SrcDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_src, 0>::type{};

        constexpr auto src_iterator_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunRead(src_desc, src_buf, src_iterator_hacks);
    }

    template <typename DstBuffer>
    __device__ void RunWrite(const DstDesc& dst_desc, DstBuffer& dst_buf)
    {
        constexpr index_t ntransform_dst = DstDesc::GetNumOfTransform();

        constexpr auto zeros = typename uniform_sequence_gen<ntransform_dst, 0>::type{};

        constexpr auto dst_iterator_hacks =
            make_tuple(generate_tuple([&](auto) { return zeros; }, Number<nDim>{}),
                       generate_tuple([&](auto) { return zeros; }, Number<nDim>{}));

        RunWrite(dst_desc, dst_buf, dst_iterator_hacks);
    }

    __device__ static constexpr auto GetSrcCoordinateResetStep()
    {
        constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};

        constexpr auto src_access_lengths = SliceLengths{} / src_vector_tensor_lengths;

        constexpr auto src_dim_access_order = SrcDimAccessOrder{};

        constexpr auto ordered_src_access_lengths =
            container_reorder_given_new2old(src_access_lengths, src_dim_access_order);

        // decide whether the last iteration moved forward or backward along each dimension
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep;

            forward_sweep(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_src_access_lengths[I0] - 1;

                static_for<0, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_src_access_lengths[j] + ordered_src_access_lengths[j] - 1;
                });

                forward_sweep(i) = tmp % 2 == 0;
            });

            return forward_sweep;
        }();

        // calculate src data index after the last iteration in RunRead(), assuming it has not
        // been reset by RunRead()
        constexpr auto src_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_src_access_lengths[i] - 1 : 0;
            });

            auto src_data_idx = container_reorder_given_old2new(ordered_idx, src_dim_access_order) *
                                src_vector_tensor_lengths;

            return src_data_idx;
        }();

        // the reset step is the negative of that index
        constexpr auto reset_src_data_step = [&]() {
            Index reset_src_data_step;

            static_for<0, nDim, 1>{}([&](auto i) { reset_src_data_step(i) = -src_data_idx[i]; });

            return reset_src_data_step;
        }();

        return reset_src_data_step;
    }

    __device__ static constexpr auto GetDstCoordinateResetStep()
    {
        constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{};

        constexpr auto dst_access_lengths = SliceLengths{} / dst_vector_tensor_lengths;

        constexpr auto dst_dim_access_order = DstDimAccessOrder{};

        constexpr auto ordered_dst_access_lengths =
            container_reorder_given_new2old(dst_access_lengths, dst_dim_access_order);

        // decide whether the last iteration moved forward or backward along each dimension
        constexpr auto forward_sweep = [&]() {
            StaticallyIndexedArray<bool, nDim> forward_sweep;

            forward_sweep(I0) = true;

            static_for<1, nDim, 1>{}([&](auto i) {
                index_t tmp = ordered_dst_access_lengths[I0] - 1;

                static_for<0, i, 1>{}([&](auto j) {
                    tmp = tmp * ordered_dst_access_lengths[j] + ordered_dst_access_lengths[j] - 1;
                });

                forward_sweep(i) = tmp % 2 == 0;
            });

            return forward_sweep;
        }();

        // calculate dst data index after the last iteration in RunWrite(), assuming it has not
        // been reset by RunWrite()
        constexpr auto dst_data_idx = [&]() {
            Index ordered_idx;

            static_for<0, nDim, 1>{}([&](auto i) {
                ordered_idx(i) = forward_sweep[i] ? ordered_dst_access_lengths[i] - 1 : 0;
            });

            auto dst_data_idx = container_reorder_given_old2new(ordered_idx, dst_dim_access_order) *
                                dst_vector_tensor_lengths;

            return dst_data_idx;
        }();

        // the reset step is the negative of that index
        constexpr auto reset_dst_data_step = [&]() {
            Index reset_dst_data_step;

            static_for<0, nDim, 1>{}([&](auto i) { reset_dst_data_step(i) = -dst_data_idx[i]; });

            return reset_dst_data_step;
        }();

        return reset_dst_data_step;
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveSrcSliceWindow(const SrcDesc& src_desc,
                                       const Index& src_slice_origin_step_idx)
    {
        // if src coord was not reset by RunRead(), then we need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(src_desc, adjusted_step_idx);

        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // src_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    template <typename SrcMoveSliceWindowIteratorHack>
    __device__ void
    MoveSrcSliceWindow(const SrcDesc& src_desc,
                       const Index& src_slice_origin_step_idx,
                       const SrcMoveSliceWindowIteratorHack& src_move_slice_window_iterator_hack)
    {
        // if src coord was not reset by RunRead(), then we need to adjust the step here
        const auto adjusted_step_idx =
            SrcResetCoordinateAfterRun ? src_slice_origin_step_idx
                                       : src_slice_origin_step_idx + GetSrcCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step = make_dynamic_tensor_coordinate_iterator(
            src_desc, adjusted_step_idx, src_move_slice_window_iterator_hack);

        move_dynamic_tensor_coordinate(src_desc, src_coord_, adjusted_step);
    }

    // dst_slice_origin_step_idx needs to be known at compile-time, for performance reasons
    __device__ void MoveDstSliceWindow(const DstDesc& dst_desc,
                                       const Index& dst_slice_origin_step_idx)
    {
        // if dst coord was not reset by RunWrite(), then we need to adjust the step here
        const auto adjusted_step_idx =
            DstResetCoordinateAfterRun ? dst_slice_origin_step_idx
                                       : dst_slice_origin_step_idx + GetDstCoordinateResetStep();

        // is it OK to construct a new step every time?
        const auto adjusted_step =
            make_dynamic_tensor_coordinate_iterator(dst_desc, adjusted_step_idx);

        move_dynamic_tensor_coordinate(dst_desc, dst_coord_, adjusted_step);
    }

    private:
    static constexpr auto buffer_desc_ =
        make_dynamic_naive_tensor_descriptor_packed_v2(sequence_to_tuple_of_number(SliceLengths{}));

    static constexpr auto buffer_size_ = buffer_desc_.GetElementSpaceSize();

    StaticBuffer<AddressSpace::Vgpr, SrcData, buffer_size_> buffer_;

    SrcCoord src_coord_;
    DstCoord dst_coord_;
};
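
One subtlety worth pinning down is GetSrcCoordinateResetStep(): after a full snake traversal the coordinate is not at the slice origin but at whatever index the last forward or backward pass ended on, and the reset step is simply that index negated. A standalone check of the arithmetic for an assumed 2 x 3 access grid (illustrative numbers, not from the source):

#include <cassert>

int main()
{
    constexpr int Len0 = 2, Len1 = 3; // access lengths, dim 0 slow / dim 1 fast

    // Dim 0 always finishes forward at Len0 - 1; dim 1 finishes forward only
    // if the number of completed slow-dim sweeps (Len0 - 1) is even.
    const int last0 = Len0 - 1;
    const int last1 = ((Len0 - 1) % 2 == 0) ? Len1 - 1 : 0;

    // With Len0 = 2 the fast dim ended on a backward sweep, i.e. at index 0.
    assert(last0 == 1 && last1 == 0);

    // The reset step RunRead() needs is the negated final index.
    assert(-last0 == -1 && -last1 == 0);
}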

// Assume:
// 1. src:
//     1. SrcDesc is known at compile-time
//     2. SrcBuffer is DynamicBuffer
//     3. src_ref_idx is known at run-time
//     4. SrcRefToOriginDisplacement is known at compile-time
//     5. use #-iterator
// 2. dst:
//     1. DstDesc is known at compile-time
//     2. DstBuffer is StaticBuffer
//     3. DstOriginIdx is known at compile-time
//     4. use direct address calculation
// 3. vector access on src
template <
    typename SrcData,
    typename DstData,
    typename SrcDesc,
    typename DstDesc,
    typename SliceLengths,
    typename DimAccessOrder,
    typename SrcVectorTensorLengths,
    typename SrcVectorTensorContiguousDimOrder,
    typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                            bool>::type = false>
struct ThreadwiseDynamicTensorSliceTransfer_v4r1
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};

    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    using SrcCoord = decltype(make_dynamic_tensor_coordinate(SrcDesc{}, Index{}));

    using SrcCoordIterator = decltype(make_dynamic_tensor_coordinate_iterator(SrcDesc{}, Index{}));

    __device__ constexpr ThreadwiseDynamicTensorSliceTransfer_v4r1(const Index& src_ref_idx)
        : src_ref_coord_(make_dynamic_tensor_coordinate(SrcDesc{}, src_ref_idx))
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to be known at compile-time");

        static_for<0, nDim, 1>{}([](auto i) {
            static_assert(SliceLengths::At(i) % SrcVectorTensorLengths::At(i) == 0, "wrong!");
        });
    }

    template <typename SrcRefToOriginDisplacement,
              typename DstOriginIdx,
              typename SrcBuffer,
              typename DstBuffer>
    __device__ void Run(const SrcDesc&,
                        const SrcRefToOriginDisplacement&,
                        const SrcBuffer& src_buf,
                        const DstDesc&,
                        const DstOriginIdx&,
                        DstBuffer& dst_buf) const
    {
        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to be known at compile-time");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename SrcBuffer::type>>,
                              remove_cv_t<remove_reference_t<SrcData>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename DstBuffer::type>>,
                                  remove_cv_t<remove_reference_t<DstData>>>::value,
                      "wrong! SrcBuffer or DstBuffer data type is wrong");

        static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer needs to be StaticBuffer");

        static_assert(
            is_known_at_compile_time<
                remove_cv_t<remove_reference_t<SrcRefToOriginDisplacement>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<DstOriginIdx>>>::value,
            "wrong! SrcRefToOriginDisplacement and DstOriginIdx need to be known "
            "at compile-time");

        // SrcDesc and DstDesc are known at compile-time
        constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
        constexpr auto dst_desc = remove_cv_t<remove_reference_t<DstDesc>>{};

        // SrcRefToOriginDisplacement and DstOriginIdx are known at compile-time
        constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{});
        constexpr auto dst_origin_idx             = to_multi_index(DstOriginIdx{});

        // tensor descriptor for src_vector
        constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{};

        constexpr auto src_vector_tensor_strides = container_reorder_given_old2new(
            container_reverse_exclusive_scan(
                container_reorder_given_new2old(src_vector_tensor_lengths,
                                                SrcVectorTensorContiguousDimOrder{}),
                math::multiplies_v2{},
                I1),
            SrcVectorTensorContiguousDimOrder{});

        constexpr auto src_vector_desc = make_dynamic_naive_tensor_descriptor_v2(
            sequence_to_tuple_of_number(src_vector_tensor_lengths),
            sequence_to_tuple_of_number(src_vector_tensor_strides));

        // access order and lengths
        constexpr auto access_lengths = SliceLengths{} / src_vector_tensor_lengths;

        constexpr auto dim_access_order = DimAccessOrder{};

        constexpr auto ordered_access_lengths =
            container_reorder_given_new2old(access_lengths, dim_access_order);

        static_ford<decltype(ordered_access_lengths)>{}([&](auto ordered_access_idx) {
            // position in slice window
            constexpr auto data_to_origin_disp_idx =
                ordered_access_idx.ReorderGivenOld2New(dim_access_order) *
                src_vector_tensor_lengths;

            // src coordinate at starting point of src_vector
            constexpr auto src_ref_to_data_disp_idx =
                src_ref_to_origin_disp_idx + data_to_origin_disp_idx;

            constexpr auto src_ref_to_data_disp_coord_iterator =
                make_dynamic_tensor_coordinate_iterator(src_desc, src_ref_to_data_disp_idx);

            auto src_data_coord = src_ref_coord_;

            move_dynamic_tensor_coordinate(
                src_desc, src_data_coord, src_ref_to_data_disp_coord_iterator);

            vector_type_maker_t<SrcData, src_vector_desc.GetElementSpaceSize()> src_vector;

            using src_vector_t = typename decltype(src_vector)::type;

            const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
                src_desc, src_data_coord);

            // copy data from src_buf into src_vector
            src_vector.template AsType<src_vector_t>()(I0) =
                src_buf.template Get<src_vector_t>(src_data_coord.GetOffset(), is_src_valid);

            // copy data from src_vector into dst_buf (also cast from SrcData to DstData)
            static_ford<SrcVectorTensorLengths>{}([&](auto src_vector_idx_) {
                constexpr auto src_vector_idx = to_multi_index(src_vector_idx_);

                constexpr index_t src_vector_offset =
                    src_vector_desc.CalculateOffset(src_vector_idx);

                constexpr index_t dst_offset = dst_desc.CalculateOffset(
                    dst_origin_idx + data_to_origin_disp_idx + src_vector_idx);

                dst_buf(Number<dst_offset>{}) = type_convert<DstData>{}(
                    src_vector.template AsType<SrcData>()[Number<src_vector_offset>{}]);
            });
        });
    }

    template <typename SrcSliceMoveStepIdx>
    __device__ void MoveSrcSliceWindow(const SrcDesc&,
                                       const SrcSliceMoveStepIdx& src_slice_move_step_idx)
    {
        constexpr auto src_desc = SrcDesc{};

        const auto src_slice_move_step_iter = make_dynamic_tensor_coordinate_iterator(
            src_desc, to_multi_index(src_slice_move_step_idx));

        move_dynamic_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
    }

    private:
    SrcCoord src_ref_coord_;
};

} // namespace ck
#endif
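
Because DstDesc and all indices in v4r1's Run() are compile-time values, every dst_offset above folds to a literal and each write becomes a direct register access with no run-time address arithmetic. A minimal sketch of that folding for a packed 4-D layout (hypothetical helper, not the CK API):

// offset = ((i0 * L1 + i1) * L2 + i2) * L3 + i3, evaluated entirely at compile time.
template <int L1, int L2, int L3>
constexpr int packed_offset_4d(int i0, int i1, int i2, int i3)
{
    return ((i0 * L1 + i1) * L2 + i2) * L3 + i3;
}

static_assert(packed_offset_4d<4, 2, 4>(1, 2, 0, 3) == 51, "folds to a constant");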

@@ -6,140 +6,6 @@

namespace ck {

// C[M, N] += transpose(A[K, M]) * B[K, N]
// Element of matrix can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ADesc,
          typename BDesc,
          typename CDesc,
          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                      CDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct ThreadwiseGemm_km_kn_mn_v1r1
{
    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(
            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
            "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatC>>>::value,
                      "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto M = CDesc{}.GetLength(I0);
        constexpr auto N = CDesc{}.GetLength(I1);
        constexpr auto K = ADesc{}.GetLength(I0);

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, K, 1>{}([&](auto k) {
            static_for<0, M, 1>{}([&](auto m) {
                constexpr index_t a_offset =
                    ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));

#if 0
                if constexpr(N == 2)
                {
                    constexpr index_t b_offset_0 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
                    constexpr index_t b_offset_1 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));

                    constexpr index_t c_offset_0 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
                    constexpr index_t c_offset_1 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));

                    amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset_0>{}],
                                                   b_buf[Number<b_offset_1>{}],
                                                   c_buf(Number<c_offset_0>{}),
                                                   c_buf(Number<c_offset_1>{}));
                }
                else if constexpr(N == 4)
                {
                    constexpr index_t b_offset_0 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
                    constexpr index_t b_offset_1 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
                    constexpr index_t b_offset_2 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
                    constexpr index_t b_offset_3 =
                        BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));

                    constexpr index_t c_offset_0 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
                    constexpr index_t c_offset_1 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
                    constexpr index_t c_offset_2 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
                    constexpr index_t c_offset_3 =
                        CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));

                    amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset_0>{}],
                                                   b_buf[Number<b_offset_1>{}],
                                                   b_buf[Number<b_offset_2>{}],
                                                   b_buf[Number<b_offset_3>{}],
                                                   c_buf(Number<c_offset_0>{}),
                                                   c_buf(Number<c_offset_1>{}),
                                                   c_buf(Number<c_offset_2>{}),
                                                   c_buf(Number<c_offset_3>{}));
                }
                else
#endif
                {
                    static_for<0, N, 1>{}([&](auto n) {
                        constexpr index_t b_offset =
                            BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
                        constexpr index_t c_offset =
                            CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));

                        amd_assembly_inner_product(a_buf[Number<a_offset>{}],
                                                   b_buf[Number<b_offset>{}],
                                                   c_buf(Number<c_offset>{}));
                    });
                }
            });
        });
    }
};
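
Spelled out without descriptors, the loop nest above is the plain transposed-A GEMM update from the comment at the top of the struct. A minimal reference version (illustrative only; row-major C arrays assumed):

// C[M][N] += A[K][M]^T * B[K][N]
void gemm_km_kn_mn_ref(int M, int N, int K, const float* a, const float* b, float* c)
{
    for(int k = 0; k < K; ++k)
        for(int m = 0; m < M; ++m)
            for(int n = 0; n < N; ++n)
                c[m * N + n] += a[k * M + m] * b[k * N + n]; // inner product along k
}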

// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// Tensor element can be vectorized data
// Assume:
@@ -227,9 +93,124 @@ struct ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
                            constexpr index_t c_offset = CDesc{}.CalculateOffset(
                                c_origin_idx + make_multi_index(m0, m1, n0, n1));

                            amd_assembly_inner_product(a_buf[Number<a_offset>{}],
                                                       b_buf[Number<b_offset>{}],
                                                       c_buf(Number<c_offset>{}));
                            amd_inner_product_dlop<FloatA, FloatB, FloatC>(
                                a_buf[Number<a_offset>{}],
                                b_buf[Number<b_offset>{}],
                                c_buf(Number<c_offset>{}));
                        });
                    });
                });
            });
        });
    }
};

// C[M0, M1, N0, N1] += A[K0, M0, M1, K1] * B[K0, N0, N1, K1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template <typename FloatA,
          typename FloatB,
          typename FloatC,
          typename ADesc,
          typename BDesc,
          typename CDesc,
          typename KLengths,
          typename MLengths,
          typename NLengths,
          typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                      CDesc::IsKnownAtCompileTime(),
                                  bool>::type = false>
struct ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1
{
    __device__ constexpr ThreadwiseGemm_k0m0m1k1_k0n0n1k1_m0m1n0n1()
    {
        static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                          CDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        // TODO: sanity-check: compare ADesc, BDesc, CDesc Size with KLengths, MLengths and NLengths

        // TODO: remove this restriction
        static_assert(KLengths::Size() == 2 && MLengths::Size() == 2 && NLengths::Size() == 2,
                      "wrong!");
    }

    template <typename ABuffer,
              typename AOriginIdx,
              typename BBuffer,
              typename BOriginIdx,
              typename CBuffer,
              typename COriginIdx>
    __device__ static void Run(const ABuffer& a_buf,
                               AOriginIdx,
                               const BBuffer& b_buf,
                               BOriginIdx,
                               CBuffer& c_buf,
                               COriginIdx)
    {
        static_assert(
            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
            "wrong! AOriginIdx, BOriginIdx, COriginIdx should be known at compile-time");

        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
                          is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
                                  remove_cv_t<remove_reference_t<FloatC>>>::value,
                      "wrong! inconsistent type");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr index_t K0 = KLengths{}[I0];
        constexpr index_t K1 = KLengths{}[I1];
        constexpr index_t M0 = MLengths{}[I0];
        constexpr index_t M1 = MLengths{}[I1];
        constexpr index_t N0 = NLengths{}[I0];
        constexpr index_t N1 = NLengths{}[I1];

        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

        static_for<0, K0, 1>{}([&](auto k0) {
            static_for<0, M0, 1>{}([&](auto m0) {
                static_for<0, M1, 1>{}([&](auto m1) {
                    static_for<0, N0, 1>{}([&](auto n0) {
                        static_for<0, N1, 1>{}([&](auto n1) {
                            vector_type<FloatA, K1> a_vec;
                            vector_type<FloatB, K1> b_vec;

                            static_for<0, K1, 1>{}([&](auto k1) {
                                constexpr index_t a_offset = ADesc{}.CalculateOffset(
                                    a_origin_idx + make_multi_index(k0, m0, m1, k1));

                                constexpr index_t b_offset = BDesc{}.CalculateOffset(
                                    b_origin_idx + make_multi_index(k0, n0, n1, k1));

                                a_vec.template AsType<FloatA>()(k1) = a_buf[Number<a_offset>{}];

                                b_vec.template AsType<FloatB>()(k1) = b_buf[Number<b_offset>{}];
                            });

                            using a_vector_t = typename vector_type<FloatA, K1>::type;
                            using b_vector_t = typename vector_type<FloatB, K1>::type;

                            constexpr index_t c_offset = CDesc{}.CalculateOffset(
                                c_origin_idx + make_multi_index(m0, m1, n0, n1));

                            amd_inner_product_dlop<a_vector_t, b_vector_t, FloatC>(
                                a_vec.template AsType<a_vector_t>()[I0],
                                b_vec.template AsType<b_vector_t>()[I0],
                                c_buf(Number<c_offset>{}));
                        });
                    });
                });
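
The K1 inner loop above packs K1 consecutive k-elements of A and B into one vector so that a single dlop instruction consumes them. A worked shape (illustrative numbers, not from the source): with int8 data and K1 = 4, a reduction of K = 32 becomes K0 = 8 outer iterations, each feeding one v_dot4; with fp16 and K1 = 2 the same K splits into K0 = 16 dot2 steps. A scalar model of the K0 x K1 accumulation order:

#include <cstdint>

std::int32_t gemm_point_k0k1(const std::int8_t* a, const std::int8_t* b, int K0, int K1)
{
    std::int32_t c = 0;
    for(int k0 = 0; k0 < K0; ++k0)     // outer: one dlop per iteration
        for(int k1 = 0; k1 < K1; ++k1) // inner: lanes packed into one vector
            c += std::int32_t(a[k0 * K1 + k1]) * std::int32_t(b[k0 * K1 + k1]);
    return c;
}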

composable_kernel/include/utility/amd_dlop.hpp (new file, 146 lines)
@@ -0,0 +1,146 @@
#ifndef CK_AMD_DLOP_HPP
#define CK_AMD_DLOP_HPP

#include "float_type.hpp"

namespace ck {

template <typename TA, typename TB, typename TC>
__device__ void amd_inner_product_dlop(const TA& a, const TB& b, TC& c);

template <>
__device__ void
amd_inner_product_dlop<float, float, float>(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
    v_fmac_f32 %0, %1, %2 \n \
    "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c += a * b;
#endif
}

#if CK_USE_AMD_DLOP
template <>
__device__ void
amd_inner_product_dlop<half2_t, half2_t, float>(const half2_t& a, const half2_t& b, float& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
    v_dot2_f32_f16 %0, %1, %2, %0\n \
    "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    c = __builtin_amdgcn_fdot2(a, b, c, false);
#endif
}

template <>
__device__ void
amd_inner_product_dlop<half4_t, half4_t, float>(const half4_t& a, const half4_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I0],
                           vector_type<half_t, 4>{b}.AsType<half2_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 4>{a}.AsType<half2_t>()[I1],
                           vector_type<half_t, 4>{b}.AsType<half2_t>()[I1],
                           c);
}

template <>
__device__ void
amd_inner_product_dlop<half8_t, half8_t, float>(const half8_t& a, const half8_t& b, float& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I0],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I1],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I1],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I2],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I2],
                           c);

    amd_inner_product_dlop(vector_type<half_t, 8>{a}.AsType<half2_t>()[I3],
                           vector_type<half_t, 8>{b}.AsType<half2_t>()[I3],
                           c);
}

template <>
__device__ void amd_inner_product_dlop<int8x4_t, int8x4_t, int32_t>(const int8x4_t& a,
                                                                    const int8x4_t& b,
                                                                    int32_t& c)
{
#if CK_USE_AMD_DLOP_INLINE_ASM
    asm volatile("\n \
    v_dot4_i32_i8 %0, %1, %2, %0\n \
    "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}
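
The dot4 instruction used above collapses four int8 multiply-accumulates into one operation; its reference semantics (my reading of the builtin, stated as an assumption rather than taken from this file) are simply:

// Scalar model of v_dot4_i32_i8 / __builtin_amdgcn_sdot4 (clamp disabled):
#include <cstdint>

inline std::int32_t dot4_i32_i8(const std::int8_t a[4], const std::int8_t b[4], std::int32_t c)
{
    for(int i = 0; i < 4; ++i)
        c += static_cast<std::int32_t>(a[i]) * static_cast<std::int32_t>(b[i]);
    return c; // accumulator returned, matching c = sdot4(a, b, c, false)
}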

template <>
__device__ void amd_inner_product_dlop<int8x8_t, int8x8_t, int32_t>(const int8x8_t& a,
                                                                    const int8x8_t& b,
                                                                    int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                           vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                           vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
                           c);
}

template <>
__device__ void amd_inner_product_dlop<int8x16_t, int8x16_t, int32_t>(const int8x16_t& a,
                                                                      const int8x16_t& b,
                                                                      int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
                           c);

    amd_inner_product_dlop(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                           vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
                           c);
}
#endif // CK_USE_AMD_DLOP

} // namespace ck
#endif
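
For the fp16 path the same pattern holds one level up: v_dot2_f32_f16 accumulates a 2-wide half-precision product into an fp32 accumulator, and the half4/half8 overloads above just peel the wide vector into half2 chunks. A scalar model (assumed semantics, for reference only):

// c += a0*b0 + a1*b1, with products widened to float as on the fp32 accumulator path.
// a and b stand in for two half lanes already converted to float.
inline float dot2_f32_f16(const float a[2], const float b[2], float c)
{
    return c + a[0] * b[0] + a[1] * b[1];
}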

@@ -5,94 +5,16 @@

namespace ck {

// c += inner_product(a, b)
__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
{
#if CK_USE_AMD_V_FMAC_F32
    asm volatile("\n \
    v_fmac_f32 %0, %1, %2 \n \
    "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#else
    asm volatile("\n \
    v_mac_f32 %0, %1, %2 \n \
    "
                 : "=v"(c)
                 : "v"(a), "v"(b), "0"(c));
#endif
}

__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
{
#if 1
    asm volatile("\n \
    v_dot4_i32_i8 %0, %1, %2, %0\n \
    "
                 : "=v"(c)
                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
#else
    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
#endif
}

__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};

    amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
                               vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
                               vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
                               c);
}

__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c)
{
    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
                               c);

    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
                               c);
}

// c0 += inner_product(a, b0)
// c1 += inner_product(a, b1)
__device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
{
#if CK_USE_AMD_V_FMAC_F32
    asm volatile("\n \
    v_fmac_f32 %0, %2, %3 \n \
    v_fmac_f32 %1, %2, %4 \n \
    "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#else
    asm volatile("\n \
    v_mac_f32 %0, %2, %3 \n \
    v_mac_f32 %1, %2, %4 \n \
    "
                 : "=v"(c0), "=v"(c1)
                 : "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
#endif
}

// c0 += inner_product(a, b0)
@@ -102,7 +24,6 @@ __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, floa
__device__ void amd_assembly_outer_product_1x4(
    float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
{
#if CK_USE_AMD_V_FMAC_F32
    asm volatile("\n \
    v_fmac_f32 %0, %4, %5 \n \
    v_fmac_f32 %1, %4, %6 \n \
@@ -111,16 +32,6 @@ __device__ void amd_assembly_outer_product_1x4(
    "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#else
    asm volatile("\n \
    v_mac_f32 %0, %4, %5 \n \
    v_mac_f32 %1, %4, %6 \n \
    v_mac_f32 %2, %4, %7 \n \
    v_mac_f32 %3, %4, %8 \n \
    "
                 : "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
                 : "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
#endif
}

// c0 += inner_product(a, b0)
@@ -28,10 +28,15 @@
#include "static_buffer.hpp"
#include "dynamic_buffer.hpp"

// TODO: remove this
#if CK_USE_AMD_INLINE_ASM
#include "amd_inline_asm.hpp"
#endif

#if CK_USE_AMD_DLOP
#include "amd_dlop.hpp"
#endif

#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp"

@@ -54,8 +54,13 @@
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
#endif

#ifndef CK_USE_AMD_V_FMAC_F32
#define CK_USE_AMD_V_FMAC_F32 1
// AMD DLOPS
#ifndef CK_USE_AMD_DLOP
#define CK_USE_AMD_DLOP 1
#endif

#ifndef CK_USE_AMD_DLOP_INLINE_ASM
#define CK_USE_AMD_DLOP_INLINE_ASM 1
#endif

// AMD buffer addressing
@@ -116,7 +121,7 @@
#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1

// merge transformation use magic number division
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 1
#define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0

// hack: has an underlying assumption that needs to be satisfied, otherwise it's a bug
// hack for forcing register to keep idx_diff_low_const in SGPR. idx_diff_low_const must be

@@ -94,7 +94,7 @@ __host__ __device__ constexpr auto container_reorder_given_old2new(Sequence<Is..

    constexpr auto new2old = typename sequence_map_inverse<Sequence<IRs...>>::type{};

    return container_reorder_give_new2old(old_seq, new2old);
    return container_reorder_given_new2old(old_seq, new2old);
}

#if !CK_WORKAROUND_SWDEV_275126
@@ -223,6 +223,13 @@ container_reverse_exclusive_scan(const Array<TData, NSize>& x, Reduce f, TData i
    return y;
}

template <index_t... Is, typename Reduce, index_t Init>
__host__ __device__ constexpr auto
container_reverse_exclusive_scan(const Sequence<Is...>& seq, Reduce f, Number<Init>)
{
    return reverse_exclusive_scan_sequence(seq, f, Number<Init>{});
}
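
A quick numeric check of what this scan computes (worked arithmetic, not from the source): for lengths Sequence<2, 3, 4> with multiplication and init 1, the reverse exclusive scan yields <12, 4, 1>, i.e. exactly the packed strides of a 2x3x4 row-major tensor, which is how the transfer classes above derive their vector-tensor strides. The same computation in plain C++ (std::array standing in for Sequence):

#include <array>
#include <cassert>

std::array<int, 3> reverse_exclusive_scan_mul(const std::array<int, 3>& len, int init)
{
    std::array<int, 3> out{};
    int running = init;
    for(int i = 2; i >= 0; --i) // walk from the fastest-varying dimension backward
    {
        out[i] = running;       // exclusive: product of all faster dimensions so far
        running *= len[i];
    }
    return out;
}

int main()
{
    const auto strides = reverse_exclusive_scan_mul({2, 3, 4}, 1);
    assert(strides[0] == 12 && strides[1] == 4 && strides[2] == 1); // packed strides
}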

#if !CK_WORKAROUND_SWDEV_275126
// rocm4.1 compiler would crash with recursive lambda
template <typename... Xs, typename Reduce, typename Init>
@@ -366,6 +373,19 @@ set_container_subset(Tuple<Ys...>& y, Sequence<Is...> picks, const Tuple<Xs...>&
    static_for<0, sizeof...(Is), 1>{}([&](auto i) { y(picks[i]) = x[i]; });
}

template <typename Container>
__host__ __device__ constexpr auto to_tuple_of_number(const Container&)
{
    static_assert(is_known_at_compile_time<Container>::value, "wrong!");

    return generate_tuple(
        [&](auto i) {
            constexpr index_t tmp = Container::At(i);
            return Number<tmp>{};
        },
        Container::Size());
}

template <index_t... Is>
__host__ __device__ constexpr auto sequence_to_tuple_of_number(Sequence<Is...>)
{

@@ -100,39 +100,72 @@ struct DynamicBuffer
            *reinterpret_cast<X*>(&p_data_[i]) = x;
#else
            // HACK: compiler would lower IR "store<i8, 16> address_space(3)" into inefficient
            // ISA, so I try to let compiler emit use IR "store<i32, 4>" which would be lower to
            // ISA, so I try to let compiler emit IR "store<i32, 4>" which would be lowered to
            // ds_write_b128
            // TODO: remove this after compiler fix
            if constexpr(is_same<typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type,
                                 int8_t>::value)
            {
                static_assert(
                    (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                     is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value) ||
                        (is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                         is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value),
                    "wrong! not implemented for this combination, please add implementation");

                if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x4_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int8_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int8_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x2_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int16_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int16_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8_t>::value &&
                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int32_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int32_t*>(&x);
                }
                if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x8_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                          int8x4_t>::value &&
                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x4_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int32_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int32_t*>(&x);
                }
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                          int8x8_t>::value &&
                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x8_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
                    *reinterpret_cast<int32x2_t*>(&p_data_[i]) =
                        *reinterpret_cast<const int32x2_t*>(&x);
                }
                if constexpr(is_same<remove_cv_t<remove_reference_t<T>>, int8x16_t>::value &&
                             is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
                else if constexpr(is_same<remove_cv_t<remove_reference_t<T>>,
                                          int8x16_t>::value &&
                                  is_same<remove_cv_t<remove_reference_t<X>>, int8x16_t>::value)
                {
                    // HACK: cast pointer of x is bad
                    // TODO: remove this after compiler fix
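
The reinterpret_cast stores above type-pun a packed int8 vector into a single wider integer store so the backend emits one wide LDS write. A standalone sketch of the same idea (illustrative only; in portable C++, memcpy is the well-defined way to express the pun, and compilers fold it into a single store):

#include <cstdint>
#include <cstring>

// Store 4 packed int8 lanes with one 32-bit write.
inline void store_int8x4_as_i32(std::int8_t* dst, const std::int8_t (&src)[4])
{
    std::int32_t word;
    std::memcpy(&word, src, sizeof(word)); // reinterpret 4 x i8 as one i32
    std::memcpy(dst, &word, sizeof(word)); // single 4-byte store
}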
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
#include "device_tensor.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
|
||||
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp"
|
||||
@@ -24,23 +26,27 @@
#define USE_DYNAMIC_MODE 1
#define USE_CONV_FWD_V4R4_NCHW 0
#define USE_CONV_FWD_V4R4_NHWC 0
#define USE_CONV_FWD_V4R4R2_NHWC 0
#define USE_CONV_FWD_V4R5_NCHW 0
#define USE_CONV_FWD_V4R5R2_NCHW 1
#define USE_CONV_FWD_V5R1_NCHW 0
#define USE_CONV_FWD_V4R4_XDL_NCHW 0
#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0
#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0

enum ConvForwardAlgo
{
    V4R4NCHW,      // 0
    V4R4NHWC,      // 1
    V4R4R2NHWC,    // 2
    V4R5NCHW,      // 3
    V4R5R2NCHW,    // 4
    V5R1NCHW,      // 5
    V4R4XDLNCHW,   // 6
    V4R4R2XDLNHWC, // 7
    V4R4R3XDLNHWC, // 8
    V4R4R4XDLNHWC  // 9
};

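Since two entries (V4R4R2NHWC and V4R5R2NCHW) were inserted, every algorithm id from V4R4R2NHWC onward shifted, and host code still passing an old integer would silently pick a different kernel. A defensive mapping sketch (the argv position is an assumption for illustration):

// Sketch: translate the command-line integer into the renumbered enum
// and reject out-of-range values instead of dispatching a wrong kernel.
const ConvForwardAlgo algo = static_cast<ConvForwardAlgo>(std::atoi(argv[1]));

if(algo < V4R4NCHW || algo > V4R4R4XDLNHWC)
{
    throw std::runtime_error("wrong! unsupported ConvForwardAlgo id");
}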
int main(int argc, char* argv[])
@@ -132,21 +138,18 @@ int main(int argc, char* argv[])
    const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
#endif

#if 1
    using in_data_t  = float;
    using acc_data_t = float;
    using out_data_t = float;
#elif 1
    using in_data_t  = half_t;
    using acc_data_t = float;
    using out_data_t = half_t;
#elif 1
    using in_data_t  = int8_t;
    using acc_data_t = int32_t;
    using out_data_t = int8_t;
#endif

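Each branch pairs the input/output element type with a wider accumulator (float for half_t, int32_t for int8_t), so partial sums do not overflow during the GemmK reduction. A back-of-envelope check for the int8 case (illustrative only):

// Worst-case int8 product magnitude is 128 * 128 = 16384 = 2^14, so an
// int32_t accumulator is safe for up to 2^31 / 2^14 = 131072 accumulated
// products - far beyond GemmK = C * Y * X for typical layers.
constexpr long long max_product = 128LL * 128LL;
constexpr long long safe_gemm_k = (1LL << 31) / max_product;
static_assert(safe_gemm_k == 131072, "int32 accumulation headroom");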
    std::vector<std::size_t> in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4);

@@ -348,6 +351,33 @@ int main(int argc, char* argv[])
    }
#endif

#if USE_CONV_FWD_V4R4R2_NHWC
    if(algo == ConvForwardAlgo::V4R4R2NHWC)
    {
        if(layout != ConvTensorLayout::NHWC)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nhwc();

        device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk<in_data_t,
                                                                               acc_data_t,
                                                                               out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in,
            wei,
            out_device,
            nrepeat);
    }
#endif

#if USE_CONV_FWD_V4R5_NCHW
    if(algo == ConvForwardAlgo::V4R5NCHW)
    {
@@ -374,6 +404,33 @@ int main(int argc, char* argv[])
    }
#endif

#if USE_CONV_FWD_V4R5R2_NCHW
    if(algo == ConvForwardAlgo::V4R5R2NCHW)
    {
        if(layout != ConvTensorLayout::NCHW)
        {
            throw std::runtime_error("wrong! layout");
        }

        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw<in_data_t,
                                                                               acc_data_t,
                                                                               out_data_t>(
            tmp[I0],
            tmp[I1],
            tmp[I2],
            tmp[I3],
            tmp[I4],
            tmp[I5],
            tmp[I6],
            in,
            wei,
            out_device,
            nrepeat);
    }
#endif

#if USE_CONV_FWD_V5R1_NCHW
    if(algo == ConvForwardAlgo::V5R1NCHW)
    {
@@ -385,7 +442,7 @@ int main(int argc, char* argv[])
        const auto tmp = f_make_for_device_nchw();

        device_dynamic_convolution_forward_implicit_gemm_v5r1_nchw_kcyx_nkhw<in_data_t,
                                                                             16,
                                                                             acc_data_t,
                                                                             out_data_t>(tmp[I0],
                                                                                         tmp[I1],
@@ -525,10 +582,10 @@ int main(int argc, char* argv[])
#if 0
    if(do_log)
    {
        LogRangeAsType<float>(std::cout << "in : ", in.mData, ",") << std::endl;
        LogRangeAsType<float>(std::cout << "wei: ", wei.mData, ",") << std::endl;
        LogRangeAsType<float>(std::cout << "out_host : ", out_host.mData, ",") << std::endl;
        LogRangeAsType<float>(std::cout << "out_device: ", out_device.mData, ",") << std::endl;
    }
#endif
}

@@ -55,7 +55,7 @@ float launch_and_time_kernel(F kernel,
{
    KernelTimer timer;

    printf("%s: grid_dim {%d, %d, %d}, block_dim {%d, %d, %d} \n",
           __func__,
           grid_dim.x,
           grid_dim.y,

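KernelTimer wraps the GPU-side timing around this launch; a minimal sketch of one way to measure the same thing with raw HIP events (illustrative, not the repo's KernelTimer implementation):

#include <hip/hip_runtime.h>

// Sketch: time one kernel launch with HIP events.
template <typename Kernel, typename... Args>
float time_kernel_ms(Kernel kernel, dim3 grid_dim, dim3 block_dim, Args... args)
{
    hipEvent_t start, stop;
    hipEventCreate(&start);
    hipEventCreate(&stop);

    hipEventRecord(start, nullptr);
    hipLaunchKernelGGL(kernel, grid_dim, block_dim, 0, nullptr, args...);
    hipEventRecord(stop, nullptr);
    hipEventSynchronize(stop);

    float ms = 0.f;
    hipEventElapsedTime(&ms, start, stop);

    hipEventDestroy(start);
    hipEventDestroy(stop);
    return ms;
}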
@@ -0,0 +1,292 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp"
#include "driver_dynamic_gemm_v1r3.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_nhwc_kyxc_nhwk(
    const InLengths& in_n_hi_wi_c_lengths,
    const WeiLengths& wei_k_y_x_c_lengths,
    const OutLengths& out_n_ho_wo_k_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_hi_wi_c,
    const Tensor<TInWei>& wei_k_y_x_c,
    Tensor<TOut>& out_n_ho_wo_k,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};
    constexpr auto I8 = Number<8>{};

    DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace());
    DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace());
    DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace());

    in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data());
    wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
    out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

    const auto in_n_hi_wi_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_hi_wi_c_lengths);
    const auto wei_k_y_x_c_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_y_x_c_lengths);
    const auto out_n_ho_wo_k_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);

#if 0
    // [M, N, K0, K1] = [128, 128, 8, 1] for fp32
    // cdata = 64, BlockSize = 256
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlockM1 = 128;
    constexpr index_t GemmNPerBlockN1 = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmK1 = 1;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

    using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;

    using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 1>;
    using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 1>;

    using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;

    using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 1>;
    using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 1>;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 8, 2] for fp16
    // cdata = 64, BlockSize = 256
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlockM1 = 128;
    constexpr index_t GemmNPerBlockN1 = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmK1 = 2;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

    using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
    using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;

    using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 2>;
    using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 2>;

    using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
    using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;

    using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 2>;
    using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 2>;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#elif 1
    // [M, N, K0, K1] = [128, 128, 8, 4] for i8
    // cdata = 64, BlockSize = 256
    constexpr index_t BlockSize = 256;

    constexpr index_t GemmMPerBlockM1 = 128;
    constexpr index_t GemmNPerBlockN1 = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmK1 = 4;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;

    using GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
    using GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1 = Sequence<2, 1, 128, 1>;

    using GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1 = Sequence<4, 1, 1, 4>;
    using GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1 = Sequence<1, 1, 1, 4>;

    using GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
    using GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1 = Sequence<2, 1, 128, 1>;

    using GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1 = Sequence<4, 1, 1, 4>;
    using GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1 = Sequence<1, 1, 1, 4>;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_N11 = 4;
#endif

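Across the three configurations, GemmK1 scales inversely with the element size (1 for fp32, 2 for fp16, 4 for int8), so one K1 group always spans 4 bytes and global loads stay dword-sized. A compile-time restatement of that invariant (illustrative):

// One GemmK1 group spans 4 bytes in every configuration above.
static_assert(sizeof(float) * 1 == 4, "fp32 config: GemmK1 = 1");
static_assert(2 * 2 == 4, "fp16 config: GemmK1 = 2 (half_t is 2 bytes)");
static_assert(sizeof(int8_t) * 4 == 4, "int8 config: GemmK1 = 4");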
    const auto descs =
        transform_forward_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad(in_n_hi_wi_c_desc,
                                                                          wei_k_y_x_c_desc,
                                                                          out_n_ho_wo_k_desc,
                                                                          conv_strides,
                                                                          conv_dilations,
                                                                          in_left_pads,
                                                                          in_right_pads,
                                                                          Number<GemmK1>{});

    const auto in_gemmk0_gemmm_gemmk1_grid_desc  = descs[I0];
    const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
    const auto out_gemmm_gemmn_grid_desc         = descs[I2];

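The descriptor names spell out the implicit-GEMM mapping the transform produces; summarized as a sketch (the dimension algebra below is inferred from those names, not from code in this hunk):

// Implicit-GEMM view of NHWC forward convolution (from the descriptor names):
//   A = input  [GemmK0, GemmM, GemmK1], GemmM = N * Ho * Wo
//   B = weight [GemmK0, GemmN, GemmK1], GemmN = K
//   C = output [GemmM, GemmN]
// with the reduction dimension GemmK = C * Y * X split as GemmK0 * GemmK1
// so that GemmK1 contiguous elements can be moved as one vector.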
    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 1+: GemmM0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GemmM1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 1-: GemmM0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GemmM1
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: GemmN0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: GemmN1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{}),  // 3+: GemmK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GemmK0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: GemmN0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: GemmN1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0>{})); // 3-: GemmK1

    constexpr auto out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: GemmM0
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: GemmM10
                              Sequence<0, 0, 0, 0, 0>{},   // 2+: GemmM11
                              Sequence<0, 0, 0, 0, 0>{},   // 3+: GemmN0
                              Sequence<0, 0, 0, 0, 0>{},   // 4+: GemmN10
                              Sequence<0, 0, 0, 0, 0>{}),  // 5+: GemmN11
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: GemmM0
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: GemmM10
                              Sequence<0, 0, 0, 0, 0>{},   // 2-: GemmM11
                              Sequence<0, 0, 0, 0, 0>{},   // 3-: GemmN0
                              Sequence<0, 0, 0, 0, 0>{},   // 4-: GemmN10
                              Sequence<0, 0, 0, 0, 0>{})); // 5-: GemmN11

    constexpr auto in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};

    constexpr auto wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_dynamic_gemm_v1r3<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperation::Set,
            decltype(in_gemmk0_gemmm_gemmk1_grid_desc),
            decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
            decltype(out_gemmm_gemmn_grid_desc),
            GemmMPerBlockM1,
            GemmNPerBlockN1,
            GemmKPerBlock,
            GemmM1PerThreadM111,
            GemmN1PerThreadN111,
            GemmKPerThread,
            GemmM11N11ThreadClusterM1100,
            GemmM11N11ThreadClusterN1100,
            GemmM11N11ThreadClusterM1101,
            GemmM11N11ThreadClusterN1101,
            GemmABlockTransferThreadSliceLengths_K0_M0_M1_K1,
            GemmABlockTransferThreadClusterLengths_K0_M0_M1_K1,
            Sequence<1, 2, 0, 3>, // ABlockTransferThreadClusterArrangeOrder
            Sequence<1, 2, 0, 3>, // ABlockTransferSrcAccessOrder
            GemmABlockTransferSrcVectorTensorLengths_K0_M0_M1_K1,
            Sequence<1, 2, 0, 3>, // ABlockTransferSrcVectorTensorContiguousDimOrder
            GemmABlockTransferDstVectorTensorLengths_K0_M0_M1_K1,
            GemmBBlockTransferThreadSliceLengths_K0_N0_N1_K1,
            GemmBBlockTransferThreadClusterLengths_K0_N0_N1_K1,
            Sequence<1, 2, 0, 3>, // BBlockTransferThreadClusterArrangeOrder
            Sequence<1, 2, 0, 3>, // BBlockTransferSrcAccessOrder
            GemmBBlockTransferSrcVectorTensorLengths_K0_N0_N1_K1,
            Sequence<1, 2, 0, 3>, // BBlockTransferSrcVectorTensorContiguousDimOrder
            GemmBBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
            Sequence<0, 1, 2, 3, 4, 5>, // CThreadTransferSrcDstAccessOrder
            5,                          // CThreadTransferSrcDstVectorDim
            GemmCThreadTransferDstScalarPerVector_N11,
            decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks),
            decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks),
            decltype(out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks),
            decltype(in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks),
            decltype(wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks)>(
            static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
            static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
            static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
            in_gemmk0_gemmm_gemmk1_grid_desc,
            wei_gemmk0_gemmn_gemmk1_grid_desc,
            out_gemmm_gemmn_grid_desc,
            in_gemmk0_gemmm0_gemmm1_gemmk1_grid_iterator_hacks,
            wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_iterator_hacks,
            out_gemmm0_gemmm10_gemmm11_gemmn0_gemmn10_gemmn11_grid_iterator_hacks,
            in_gemmk0_gemmm0_gemmm1_gemmk1_grid_move_slice_window_iterator_hacks,
            wei_gemmk0_gemmn0_gemmn1_gemmk1_grid_move_slice_window_iterator_hacks,
            nrepeat);

        {
            const auto N = out_n_ho_wo_k_lengths[I0];
            const auto K = out_n_ho_wo_k_lengths[I3];
            const auto C = wei_k_y_x_c_lengths[I3];

            const auto Hi = in_n_hi_wi_c_lengths[I1];
            const auto Wi = in_n_hi_wi_c_lengths[I2];

            const auto Ho = out_n_ho_wo_k_lengths[I1];
            const auto Wo = out_n_ho_wo_k_lengths[I2];

            const auto Y = wei_k_y_x_c_lengths[I1];
            const auto X = wei_k_y_x_c_lengths[I2];

            float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) /
                         (std::size_t(1000) * 1000 * 1000) / ave_time;

            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
                      << std::endl;
        }
    }

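The perf expression above is already in TFlop/s: dividing FLOPs by 1e9 and then by a time in milliseconds equals dividing by 1e12 and by seconds. A worked example with illustrative sizes:

// flops = 2 * N * K * Ho * Wo * C * Y * X (one multiply + one add per MAC)
// e.g. N=128, K=256, Ho=Wo=14, C=256, Y=X=3:
//   flops = 2 * 128 * 256 * 14 * 14 * 256 * 3 * 3 ~= 2.96e10
// at ave_time = 5 ms: 2.96e10 / 1e9 / 5 ~= 5.9 TFlop/s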
    // copy result back to host
    out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data());
}

@@ -275,12 +275,12 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
                              Sequence<0, 0, 0, 0, 0>{},   // 6-: M2
                              Sequence<0, 0, 0, 0, 0>{})); // 7-: N1

    constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{};

    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_dynamic_gemm_xdlops_v2r3<

@@ -0,0 +1,249 @@
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "transform_forward_convolution_into_gemm_v4r5r2_nchw_kcyx_nkhw.hpp"
#include "driver_dynamic_contraction_v1r2.hpp"

template <typename TInWei,
          typename TAcc,
          typename TOut,
          typename InLengths,
          typename WeiLengths,
          typename OutLengths,
          typename ConvStrides,
          typename ConvDilations,
          typename InLeftPads,
          typename InRightPads>
void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
    const InLengths& in_n_c_hi_wi_lengths,
    const WeiLengths& wei_k_c_y_x_lengths,
    const OutLengths& out_n_k_ho_wo_lengths,
    const ConvStrides& conv_strides,
    const ConvDilations& conv_dilations,
    const InLeftPads& in_left_pads,
    const InRightPads& in_right_pads,
    const Tensor<TInWei>& in_n_c_hi_wi,
    const Tensor<TInWei>& wei_k_c_y_x,
    Tensor<TOut>& out_n_k_ho_wo,
    ck::index_t nrepeat)
{
    using namespace ck;

    std::cout << __func__ << std::endl;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};

    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());

    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());

    const auto in_n_c_hi_wi_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(in_n_c_hi_wi_lengths);
    const auto wei_k_c_y_x_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(wei_k_c_y_x_lengths);
    const auto out_n_k_ho_wo_desc =
        make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);

#if 1
    // [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] for fp32
    // cdata = 64, BlockSize = 256
    constexpr index_t BlockSize = 256;

    constexpr index_t GN0 = 4;
    constexpr index_t GK1 = 1;

    constexpr index_t GemmGM1PerBlockGM11 = 128;
    constexpr index_t GemmGN1PerBlockGN11 = 32;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

    using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
    using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;

    using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
    using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 1>;

    using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 1>;
    using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;

    using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
    using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#elif 1
    // [8, 1, 128, 2] * [8, 4, 32, 2] = [1, 128, 4, 32] for fp16
    // cdata = 64, BlockSize = 256
    constexpr index_t BlockSize = 256;

    constexpr index_t GN0 = 4;
    constexpr index_t GK1 = 2;

    constexpr index_t GemmGM1PerBlockGM11 = 128;
    constexpr index_t GemmGN1PerBlockGN11 = 32;
    constexpr index_t GemmKPerBlock = 8;

    constexpr index_t GemmM1PerThreadM111 = 4;
    constexpr index_t GemmN1PerThreadN111 = 4;
    constexpr index_t GemmKPerThread = 1;

    constexpr index_t GemmM11N11ThreadClusterM1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterN1101 = 2;
    constexpr index_t GemmM11N11ThreadClusterM1100 = 8;
    constexpr index_t GemmM11N11ThreadClusterN1100 = 8;

    using GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 2>;
    using GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<2, 1, 1, 128, 1>;

    using GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<4, 1, 1, 1, 1>;
    using GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1 = Sequence<1, 1, 1, 1, 2>;

    using GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 4, 1, 1, 2>;
    using GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<8, 1, 1, 32, 1>;

    using GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 1>;
    using GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1 = Sequence<1, 1, 1, 1, 2>;

    constexpr index_t GemmCThreadTransferDstScalarPerVector_BN1 = 1;
#endif

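The tile comments above encode the per-block contraction shape; the fp32 line [8, 1, 128, 1] * [8, 4, 32, 1] = [1, 128, 4, 32] checks out against "cdata = 64, BlockSize = 256" (illustrative restatement):

// A block tile [GK0, GM0, GM11, GK1] = [8, 1, 128, 1]
// B block tile [GK0, GN0, GN11, GK1] = [8, 4,  32, 1]
// C block tile [GM0, GM11, GN0, GN11] = [1, 128, 4, 32]
// C elements per block: 128 * 4 * 32 = 16384
// per thread at BlockSize = 256: 16384 / 256 = 64, i.e. "cdata = 64"
static_assert(128 * 4 * 32 / 256 == 64, "64 C results per thread");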
    const auto descs =
        transform_forward_convolution_into_contraction_v4r5r2_nchw_kcyx_nkhw_pad(wei_k_c_y_x_desc,
                                                                                 in_n_c_hi_wi_desc,
                                                                                 out_n_k_ho_wo_desc,
                                                                                 conv_strides,
                                                                                 conv_dilations,
                                                                                 in_left_pads,
                                                                                 in_right_pads,
                                                                                 Number<GN0>{},
                                                                                 Number<GK1>{});

    const auto wei_gk0_gm0_gm1_gk1_grid_desc = descs[I0];
    const auto in_gk0_gn0_gn1_gk1_grid_desc  = descs[I1];
    const auto out_gm0_gm1_gn0_gn1_grid_desc = descs[I2];

    // HACK: hacks that control index calculation when iterating over A, B, C matrix
    constexpr auto wei_grid_iterator_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 0+: GK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 1+: GM0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2+: GM10
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 3+: GM11
                              Sequence<0, 0, 0, 0, 0, 0, 0>{}),  // 4+: GK1
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 0-: GK0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 1-: GM0
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 2-: GM10
                              Sequence<0, 0, 0, 0, 0, 0, 0>{},   // 3-: GM11
                              Sequence<0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1

    constexpr auto in_grid_iterator_hacks = make_tuple(
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: GK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 1+: GN0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 2+: GN10
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{},   // 3+: GN11
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 4+: GK1
        make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GK0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 1-: GN0
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 2-: GN10
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{},   // 3-: GN11
                   Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 4-: GK1

    constexpr auto out_grid_iterator_hacks = make_tuple(
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 0+: GM10
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},  // 1+: BM0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{},  // 2+: BM1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},  // 3+: GN10
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{},  // 4+: BN0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 5+: GN1
        make_tuple(
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: GM10
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},   // 1-: BM0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{},   // 2-: BM1
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: GN10
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{},   // 4-: BN0
            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 5-: GN1

    constexpr auto wei_grid_move_slice_window_iterator_hacks = Sequence<0, 0, 0, 0, 0, 0, 0>{};

    constexpr auto in_grid_move_slice_window_iterator_hacks =
        Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 2, 0, 0, 0, 0, 0>{};

    for(index_t i = 0; i < 5; ++i)
    {
        float ave_time = driver_dynamic_contraction_v1r2<
            BlockSize,
            TInWei,
            TAcc,
            TOut,
            InMemoryDataOperation::Set,
            decltype(wei_gk0_gm0_gm1_gk1_grid_desc),
            decltype(in_gk0_gn0_gn1_gk1_grid_desc),
            decltype(out_gm0_gm1_gn0_gn1_grid_desc),
            GemmGM1PerBlockGM11,
            GemmGN1PerBlockGN11,
            GemmKPerBlock,
            GemmM1PerThreadM111,
            GemmN1PerThreadN111,
            GemmKPerThread,
            GemmM11N11ThreadClusterM1100,
            GemmM11N11ThreadClusterN1100,
            GemmM11N11ThreadClusterM1101,
            GemmM11N11ThreadClusterN1101,
            GemmABlockTransferThreadSliceLengths_GK0_GM0_GM10_GM11_GK1,
            GemmABlockTransferThreadClusterLengths_GK0_GM0_GM10_GM11_GK1,
            Sequence<1, 2, 3, 0, 4>, // ABlockTransferThreadClusterArrangeOrder
            Sequence<3, 2, 1, 0, 4>, // ABlockTransferSrcAccessOrder
            GemmABlockTransferSrcVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
            GemmABlockTransferDstVectorTensorLengths_GK0_GM0_GM10_GM11_GK1,
            Sequence<0, 1, 2, 3, 4>, // ABlockTransferSrcVectorTensorContiguousDimOrder
            GemmBBlockTransferThreadSliceLengths_GK0_GN0_GN10_GN11_GK1,
            GemmBBlockTransferThreadClusterLengths_GK0_GN0_GN10_GN11_GK1,
            Sequence<0, 4, 1, 2, 3>, // BBlockTransferThreadClusterArrangeOrder
            Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
            GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
            Sequence<0, 1, 2, 3, 4>,    // BBlockTransferSrcVectorTensorContiguousDimOrder
            Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
            5,                          // CThreadTransferSrcDstVectorDim
            GemmCThreadTransferDstScalarPerVector_BN1,
            decltype(wei_grid_iterator_hacks),
            decltype(in_grid_iterator_hacks),
            decltype(out_grid_iterator_hacks),
            decltype(wei_grid_move_slice_window_iterator_hacks),
            decltype(in_grid_move_slice_window_iterator_hacks)>(
            static_cast<TInWei*>(wei_k_c_y_x_device_buf.GetDeviceBuffer()),
            static_cast<TInWei*>(in_n_c_hi_wi_device_buf.GetDeviceBuffer()),
            static_cast<TOut*>(out_n_k_ho_wo_device_buf.GetDeviceBuffer()),
            wei_gk0_gm0_gm1_gk1_grid_desc,
            in_gk0_gn0_gn1_gk1_grid_desc,
            out_gm0_gm1_gn0_gn1_grid_desc,
            wei_grid_iterator_hacks,
            in_grid_iterator_hacks,
            out_grid_iterator_hacks,
            wei_grid_move_slice_window_iterator_hacks,
            in_grid_move_slice_window_iterator_hacks,
            nrepeat);

        float perf = (float)calculate_convolution_flops(
                         in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    // copy result back to host
    out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data());
}
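A minimal host-side use of the new v4r5r2 entry point (a sketch: the tensor setup, the f_make_for_device_nchw helper, and all sizes are assumptions modeled on the driver in this PR, not code from this file):

// Hypothetical NCHW problem: N=128, C=192, Hi=Wi=71, K=256, Y=X=3,
// stride 2, dilation 1, pad 1 -> Ho = Wo = (71 + 2 - 3) / 2 + 1 = 36.
Tensor<float> in(HostTensorDescriptor(std::vector<std::size_t>{128, 192, 71, 71}));
Tensor<float> wei(HostTensorDescriptor(std::vector<std::size_t>{256, 192, 3, 3}));
Tensor<float> out(HostTensorDescriptor(std::vector<std::size_t>{128, 256, 36, 36}));

// tmp packs lengths, strides, dilations and pads as ck tuples
// (assumed helper, mirroring main() above).
const auto tmp = f_make_for_device_nchw();

device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw<float, float, float>(
    tmp[I0], tmp[I1], tmp[I2], tmp[I3], tmp[I4], tmp[I5], tmp[I6],
    in, wei, out, /*nrepeat=*/5);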