Mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-13 01:36:06 +00:00)

Refactor for MIOpen integration (#4)

Refactor so that the multi-index transformation and padding support can be brought into MIOpen.
@@ -0,0 +1,14 @@
#ifndef CK_CONVOLUTION_COMMON_HPP
#define CK_CONVOLUTION_COMMON_HPP

namespace ck {

enum ConvolutionDirection
{
    Forward,
    BackwardData,
    BackwardWeight
};

} // namespace ck
#endif
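Editorial note: this new header is the compile-time switch the rest of the commit keys on; the v4r1 kernel below specializes a helper template per ConvolutionDirection to build its weight descriptor. A minimal, hypothetical sketch of that dispatch pattern (the DirectionName struct and main() are illustrative only, not part of the commit):

#include <cstdio>

enum ConvolutionDirection
{
    Forward,
    BackwardData,
    BackwardWeight
};

// Primary template intentionally left undefined: using an unsupported direction
// fails at compile time, just like the per-direction helper specializations
// added to the v4r1 kernel later in this diff.
template <ConvolutionDirection Dir>
struct DirectionName;

template <>
struct DirectionName<Forward>
{
    static constexpr const char* value = "forward";
};

template <>
struct DirectionName<BackwardWeight>
{
    static constexpr const char* value = "backward-weight";
};

int main()
{
    std::printf("%s\n", DirectionName<Forward>::value); // prints "forward"
    return 0;
}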
@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_DIRECT_V2_NCHW_KCYX_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "threadwise_tensor_slice_copy.hpp"

@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R1_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"

@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R2_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_3d_tensor_op.hpp"

@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
@@ -125,38 +125,38 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
3,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
InBlockCopyDataPerAccess_N>({0, 0, 0, 0},
|
||||
{0, 0, 0, 0});
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
decltype(in_c_h_w_n_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_CHWN,
|
||||
InBlockCopyClusterLengths_CHWN,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
Sequence<0, 1, 2, 3>,
|
||||
3,
|
||||
3,
|
||||
InBlockCopyDataPerAccess_N,
|
||||
InBlockCopyDataPerAccess_N>({0, 0, 0, 0}, {0, 0, 0, 0});
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, 0}, {0, 0});
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_CK,
|
||||
WeiBlockCopyClusterLengths_CK,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
Sequence<0, 1>,
|
||||
1,
|
||||
1,
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, 0},
|
||||
{0, 0});
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
@@ -318,14 +318,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
@@ -388,14 +389,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
|
||||
@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
@@ -127,9 +127,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
auto blockwise_in_copy =
|
||||
#if 0
|
||||
BlockwiseGenericTensorSliceCopy_v1
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
#else
|
||||
BlockwiseGenericTensorSliceCopy_v2
|
||||
BlockwiseGenericTensorSliceCopy_v2_deprecated
|
||||
#endif
|
||||
<BlockSize,
|
||||
decltype(in_c_h_w_n_global_desc),
|
||||
@@ -149,9 +149,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
// format is [CPerBlock, X * KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
#if 0
|
||||
BlockwiseGenericTensorSliceCopy_v1
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
#else
|
||||
BlockwiseGenericTensorSliceCopy_v2
|
||||
BlockwiseGenericTensorSliceCopy_v2_deprecated
|
||||
#endif
|
||||
<BlockSize,
|
||||
decltype(wei_c_k_global_desc),
|
||||
@@ -406,14 +406,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
@@ -476,14 +477,15 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn_lds_double_buffer
|
||||
n_block_data_begin + n_thread_data_begin);
|
||||
|
||||
#if 1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(
|
||||
make_zero_array<index_t, 10>(), make_zero_array<index_t, 10>())
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<
|
||||
decltype(out_10d_thread_desc),
|
||||
decltype(out_10d_global_desc),
|
||||
decltype(out_10d_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 10, 1>::type,
|
||||
9,
|
||||
OutThreadCopyDataPerAccess_N,
|
||||
OutThreadCopyDataPerAccess_N>(make_zero_array<index_t, 10>(),
|
||||
make_zero_array<index_t, 10>())
|
||||
.Run(p_out_thread, p_out_thread_on_global);
|
||||
#elif 0
|
||||
ThreadwiseGenericTensorSliceCopy_v1r1<decltype(out_10d_thread_desc),
|
||||
|
||||
@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_CHWN_CYXK_KHWN_PADDED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"

@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_copy.hpp"

@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V1R3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_2d_tensor_op.hpp"
#include "blockwise_tensor_slice_copy.hpp"

@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"

@@ -2,7 +2,7 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V2_CHWN_CYXK_KHWN_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_4d_tensor_op.hpp"
#include "blockwise_2d_tensor_op.hpp"
@@ -2,8 +2,8 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
@@ -128,7 +128,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_n1_b_n2_global_merged_desc),
|
||||
@@ -155,20 +155,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
|
||||
@@ -2,8 +2,8 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V3_NCHW_CYXK_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
@@ -125,7 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
|
||||
const auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(in_c_n1_b_n2_global_merged_desc),
|
||||
@@ -152,20 +152,19 @@ struct GridwiseConvolutionImplicitGemm_v3_nchw_cyxk_nkhw_lds_double_buffer
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
const auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>(
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
const auto blockwise_wei_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
|
||||
BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopySubLengths_C_K,
|
||||
WeiBlockCopyClusterLengths_C_K,
|
||||
Sequence<0, 1>, // thread_arrange_order [C, K]
|
||||
Sequence<0, 1>, // src_access_order [C, K]
|
||||
Sequence<0, 1>, // dst_access_order [C, K]
|
||||
WeiBlockCopyDataPerAccess_K,
|
||||
WeiBlockCopyDataPerAccess_K>({0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
|
||||
@@ -2,24 +2,71 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
#include "convolution_common.hpp"

namespace ck {

// define B = merge(N0, Ho, Wo)
template <ConvolutionDirection>
struct make_wei_e_k_global_desc_v4r1;

template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::Forward>
{
    template <typename WeiDesc>
    __device__ constexpr auto operator()(WeiDesc) const
    {
        constexpr auto I1 = Number<1>{};
        constexpr auto I3 = Number<3>{};

        return reorder_tensor_descriptor_given_upper2lower(
            unfold_tensor_descriptor(WeiDesc{}, I1, I3), Sequence<1, 0>{});
    }
};

template <>
struct make_wei_e_k_global_desc_v4r1<ConvolutionDirection::BackwardWeight>
{
    template <typename WeiDesc>
    __device__ constexpr auto operator()(WeiDesc) const
    {
        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto wei_k_c_y_x_global_desc = WeiDesc{};

        constexpr index_t K = wei_k_c_y_x_global_desc.GetLength(I0);
        constexpr index_t C = wei_k_c_y_x_global_desc.GetLength(I1);
        constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
        constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

        return transform_tensor_descriptor(
            unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I2, I3),
            make_tuple(Merge<Sequence<C, Y * X>>{}, PassThrough<K>{}),
            make_tuple(Sequence<1, 2>{}, Sequence<0>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));
    }
};
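Editorial note, not part of the diff: both specializations above expose the [K, C, Y, X] weight tensor as a 2-D [E, K] matrix with E = C * Y * X, which becomes the A operand of the implicit GEMM. Assuming a packed row-major K, C, Y, X layout, the mapping from an (e, k) coordinate of that view back to a storage offset is just the following index arithmetic (a standalone sketch; the helper name is illustrative):

#include <cassert>
#include <cstddef>

// offset of element (e, k) of the E x K weight view, where e linearizes (c, y, x)
// and the underlying storage is packed [K, C, Y, X] with strides (C*Y*X, Y*X, X, 1)
std::size_t wei_e_k_offset(std::size_t e, std::size_t k,
                           std::size_t C, std::size_t Y, std::size_t X)
{
    const std::size_t c  = e / (Y * X);
    const std::size_t yx = e % (Y * X);
    const std::size_t y  = yx / X;
    const std::size_t x  = yx % X;
    return ((k * C + c) * Y + y) * X + x;
}

int main()
{
    // K = 2, C = 3, Y = 3, X = 3: element (e = 5, k = 1) is (c = 0, y = 1, x = 2) of filter 1
    assert(wei_e_k_offset(5, 1, 3, 3, 3) == 1 * 27 + 0 * 9 + 1 * 3 + 2);
    return 0;
}

The two specializations appear to differ only in how they fold the source dimensions into E: the forward path unfolds C..X and swaps the two remaining dimensions, while the backward-weight path keeps K in place and merges C with the unfolded Y*X.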
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
typename Float,
|
||||
typename AccDataType,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
ConvolutionDirection ConvDirection,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
@@ -33,18 +80,18 @@ template <index_t GridSize,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopySubLengths_E_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
typename InBlockCopySubLengths_E_N1_B_N2,
|
||||
typename InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
typename InBlockCopyThreadClusterArrangeOrder,
|
||||
typename InBlockCopySrcAccessOrder,
|
||||
typename InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
typename WeiBlockCopySubLengths_E_K,
|
||||
typename WeiBlockCopyClusterLengths_E_K,
|
||||
typename WeiBlockCopyThreadClusterArrangeOrder,
|
||||
typename WeiBlockCopySrcAccessOrder,
|
||||
typename WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
@@ -53,6 +100,22 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
constexpr auto global_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::global>{};
|
||||
|
||||
static_assert(ConvDirection == ConvolutionDirection::Forward ||
|
||||
ConvDirection == ConvolutionDirection::BackwardWeight,
|
||||
"wrong! this kernel only support convolution forward and backward-weight");
|
||||
|
||||
// this is a mess
|
||||
// TODO: find more elegant way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
@@ -63,24 +126,18 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto I5 = Number<5>{};
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
@@ -106,46 +163,51 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
"be violated");
|
||||
|
||||
// divide block work by [K, B]
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
|
||||
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
|
||||
"wrong! cannot divide work evenly among block");
|
||||
|
||||
constexpr index_t KBlockWork = K / KPerBlock;
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
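Editorial note, not part of the diff: the deprecated helper (GetMultiIndexFrom1dIndex) and its replacement (CalculateClusterIndex) are doing the same job here, decomposing the 1-D hardware block id into a (K-tile, B-tile) coordinate over Sequence<KBlockWork, BBlockWork>. Assuming the usual row-major decomposition, that is simply:

#include <cassert>

struct BlockWorkId
{
    int k_tile;
    int b_tile;
};

// row-major split of a 1-D block id over a KBlockWork x BBlockWork grid
// (assumed to match what CalculateClusterIndex computes for this descriptor)
BlockWorkId decompose_block_id(int block_id, int BBlockWork)
{
    return BlockWorkId{block_id / BBlockWork, block_id % BBlockWork};
}

int main()
{
    // e.g. BBlockWork = 8: block 11 handles K-tile 1, B-tile 3
    const BlockWorkId id = decompose_block_id(11, 8);
    assert(id.k_tile == 1 && id.b_tile == 3);
    return 0;
}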
|
||||
|
||||
// input tensor
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
// global tensor in global memory
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
|
||||
// batch descriptor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<3, 6, 7>{},
|
||||
Sequence<5>{});
|
||||
// global tensor in global memory, src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
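Editorial note, not part of the diff: the chain of transforms above (Pad, then UnMerge/Embed, then Merge) is what lets the padded convolution be read as a GEMM with E = merge(C, Y, X) and B = merge(N0, Ho, Wo). Under the Embed coefficients written above, a source coordinate (e, n1, b, n2) lands in the padded input at hi_p = y*ConvDilationH + ho*ConvStrideH and wi_p = x*ConvDilationW + wo*ConvStrideW. A scalar sketch of that coordinate math, with every size passed explicitly (this helper is illustrative only):

// Where an (e, n1, b, n2) element of the merged source descriptor comes from
// in the original N, C, Hi, Wi input; hi/wi may fall inside the pad region.
struct InCoord
{
    int n, c, hi, wi;
};

InCoord in_coord_from_e_b(int e, int n1, int b, int n2,
                          int Y, int X, int N1, int N2, int Ho, int Wo,
                          int stride_h, int stride_w,
                          int dilation_h, int dilation_w,
                          int left_pad_h, int left_pad_w)
{
    // split e = (c, y, x) and b = (n0, ho, wo)
    const int c  = e / (Y * X);
    const int y  = (e / X) % Y;
    const int x  = e % X;
    const int n0 = b / (Ho * Wo);
    const int ho = (b / Wo) % Ho;
    const int wo = b % Wo;

    // UnMerge<N0, N1, N2>: n = (n0 * N1 + n1) * N2 + n2
    const int n = (n0 * N1 + n1) * N2 + n2;

    // Embed into the padded input, then shift back by the left pads
    const int hi = y * dilation_h + ho * stride_h - left_pad_h;
    const int wi = x * dilation_w + wo * stride_w - left_pad_w;

    return InCoord{n, c, hi, wi};
}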
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// block tensor in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
@@ -154,12 +216,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
static_assert(in_e_n1_b_n2_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
|
||||
"GemmDataPerReadB alignment requirement is not satisfied");
|
||||
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
// input tensor blockwise copy
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
@@ -174,21 +234,27 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
// global tensor in global memory, src of blockwise copy
|
||||
// It is constructed differently, depending on whether forward or backward weight
|
||||
// convolution
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
make_wei_e_k_global_desc_v4r1<ConvDirection>{}(wei_k_c_y_x_global_desc);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// block tensor in LDS memory, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// operator for blockwise copy of weight into LDS
|
||||
// slice a tensor, and copy it into another tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");
|
||||
|
||||
// weight tensor blockwise copy
|
||||
auto blockwise_wei_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(wei_e_k_global_desc),
|
||||
decltype(wei_e_k_block_desc),
|
||||
decltype(wei_e_k_block_desc.GetLengths()),
|
||||
@@ -204,15 +270,18 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{0, k_block_data_on_global}, {0, 0});
|
||||
|
||||
// GEMM definition
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx[EPerBlock, KPerBlock] is in LDS
|
||||
// b_mtx[EPerBlock, N1 * BPerBlock * N2] is in LDS
|
||||
// c_mtx[KPerBlock, N1 * BPerBlock * N2] is distributed among threads, and saved in
|
||||
// register
|
||||
// register
|
||||
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);
|
||||
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc =
|
||||
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));
|
||||
constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
in_e_n1_b_n2_block_desc.GetLength(I0),
|
||||
in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
|
||||
in_e_n1_b_n2_block_desc.GetLength(I3),
|
||||
in_e_n1_b_n2_block_desc.GetStride(I0));
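Editorial note, not part of the diff: whichever way the B matrix descriptor is built (Unfold on the deprecated path, or the explicit lengths and stride above), the per-block GEMM described by these descriptors is C[KPerBlock, N1*BPerBlock*N2] += transpose(A[EPerBlock, KPerBlock]) * B[EPerBlock, N1*BPerBlock*N2]. A naive reference loop with those shapes:

// naive reference for the blockwise GEMM C += transpose(A) * B,
// with A stored as [E][K], B as [E][N], C as [K][N], all row-major
void blockwise_gemm_reference(const float* a, const float* b, float* c,
                              int E, int K, int N)
{
    for(int e = 0; e < E; ++e)
        for(int k = 0; k < K; ++k)
            for(int n = 0; n < N; ++n)
                c[k * N + n] += a[e * K + k] * b[e * N + n];
}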
|
||||
|
||||
// sanity check
|
||||
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
|
||||
@@ -258,17 +327,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
__shared__ Float p_wei_block_double[2 * wei_block_space];
|
||||
|
||||
// register allocation for output
|
||||
Float p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];
|
||||
|
||||
// zero out threadwise output
|
||||
threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);
|
||||
|
||||
// LDS double buffer: preload data into LDS
|
||||
{
|
||||
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_block_double);
|
||||
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_block_double);
|
||||
blockwise_in_copy.Run(
|
||||
p_in_global, p_in_block_double, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
// LDS double buffer: main body
|
||||
@@ -299,12 +368,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy
|
||||
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy
|
||||
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
|
||||
@@ -317,60 +384,84 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
|
||||
// LDS double buffer: tail
|
||||
{
|
||||
// even iteration
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);
|
||||
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
if(has_two_iteration_left) // if has 2 iteration left
|
||||
{
|
||||
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
|
||||
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
|
||||
|
||||
__syncthreads();
|
||||
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0, 0, 0>{}, True);
|
||||
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
|
||||
|
||||
// LDS double buffer: load next data from device mem
|
||||
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_in_global, p_in_thread_buffer);
|
||||
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
|
||||
p_wei_global, p_wei_thread_buffer);
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
// LDS double buffer: load last data from device mem
|
||||
blockwise_in_copy.RunLoadThreadBuffer(
|
||||
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
|
||||
blockwise_wei_copy.RunLoadThreadBuffer(
|
||||
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);
|
||||
|
||||
// LDS double buffer: store next data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
// LDS double buffer: GEMM on 2nd-last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
|
||||
// odd iteration
|
||||
__syncthreads();
|
||||
// LDS double buffer: store last data to LDS
|
||||
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
|
||||
p_in_block_double + in_block_space);
|
||||
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
|
||||
p_wei_block_double + wei_block_space);
|
||||
|
||||
// LDS double buffer: GEMM on current data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
|
||||
p_in_block_double + in_block_space,
|
||||
p_out_thread);
|
||||
}
|
||||
else // if has 1 iteration left
|
||||
{
|
||||
__syncthreads();
|
||||
|
||||
// LDS double buffer: GEMM on last data
|
||||
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
|
||||
}
|
||||
}
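Editorial note, not part of the diff: the rewritten tail no longer requires E % (2 * EPerBlock) == 0; it decides at this point whether one or two EPerBlock-sized reduction steps are still outstanding and finishes accordingly. In isolation, that decision is:

enum class TailKind
{
    OneIterationLeft,
    TwoIterationsLeft
};

// mirrors `has_two_iteration_left = (E % (2 * EPerBlock) == 0)` above:
// an even number of EPerBlock steps leaves two iterations for the tail,
// an odd number leaves a single one
TailKind tail_kind(int E, int EPerBlock)
{
    return (E % (2 * EPerBlock) == 0) ? TailKind::TwoIterationsLeft
                                      : TailKind::OneIterationLeft;
}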
|
||||
|
||||
// copy output: register to global memory
|
||||
{
|
||||
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
|
||||
constexpr index_t K0 = K / K1;
|
||||
|
||||
// define tensor descriptor for threadwise copy
|
||||
// output memory layout descriptor in register, src of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
|
||||
// define output tensor descriptor for threadwise copy
|
||||
// thread output tensor, src of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
|
||||
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});
|
||||
|
||||
// output memory layout descriptor in device memory
|
||||
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
|
||||
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});
|
||||
// global output tensor
|
||||
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
|
||||
out_n_k_ho_wo_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
|
||||
UnMerge<Sequence<K0, K1>>{},
|
||||
PassThrough<Ho>{},
|
||||
PassThrough<Wo>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));
|
||||
|
||||
// output merged global tensor descriptor, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
|
||||
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
|
||||
Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 5, 6>{},
|
||||
Sequence<2>{});
|
||||
// global output tensor, dst of threadwise copy
|
||||
constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
out_n0_n1_n2_k0_k1_ho_wo_global_desc,
|
||||
make_tuple(PassThrough<K0>{},
|
||||
PassThrough<K1>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3>{},
|
||||
Sequence<4>{},
|
||||
Sequence<1>{},
|
||||
Sequence<0, 5, 6>{},
|
||||
Sequence<2>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
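Editorial note, not part of the diff: the output copy splits the global K index into (k0, k1) with K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster, which is exactly the `/ K1` and `% K1` pair used for the destination origin further down. As plain arithmetic:

#include <cassert>

struct K0K1
{
    int k0;
    int k1;
};

// split a global K index into the (k0, k1) pair used by the out_k0_k1_n1_b_n2 descriptors
K0K1 split_k(int k_global, int K1)
{
    return K0K1{k_global / K1, k_global % K1};
}

int main()
{
    // e.g. K1 = 64: global output channel 130 sits at k0 = 2, k1 = 2
    const K0K1 kk = split_k(130, 64);
    assert(kk.k0 == 2 && kk.k1 == 2);
    return 0;
}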
|
||||
|
||||
// calculate origin of thread output tensor on global memory
|
||||
// blockwise GEMM c matrix starting index
|
||||
@@ -383,26 +474,23 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
const index_t b_thread_data_on_global =
|
||||
b_block_data_on_global + c_thread_mtx_on_block.col / N2;
|
||||
|
||||
ThreadwiseGenericTensorSliceCopy_v2r1<
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
3,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.template Run<Float, Float, address_space_t::generic, address_space_t::global>(
|
||||
p_out_thread, p_out_global);
|
||||
ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
|
||||
decltype(out_k0_k1_n1_b_n2_global_desc),
|
||||
decltype(
|
||||
out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
|
||||
arithmetic_sequence_gen<0, 5, 1>::type,
|
||||
3,
|
||||
1,
|
||||
1>({0, 0, 0, 0, 0},
|
||||
{k_thread_data_on_global / K1,
|
||||
k_thread_data_on_global % K1,
|
||||
0,
|
||||
b_thread_data_on_global,
|
||||
0})
|
||||
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
#endif
|
||||
|
||||
@@ -1,27 +1,58 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy.hpp"
|
||||
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
#include "blockwise_gemm.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
#include "convolution_common.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <ConvolutionDirection>
|
||||
struct make_wei_e_k_global_desc_v4r1_deprecated;
|
||||
|
||||
template <>
|
||||
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::Forward>
|
||||
{
|
||||
template <typename WeiDesc>
|
||||
__device__ constexpr auto operator()(WeiDesc) const
|
||||
{
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return WeiDesc::Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct make_wei_e_k_global_desc_v4r1_deprecated<ConvolutionDirection::BackwardWeight>
|
||||
{
|
||||
template <typename WeiDesc>
|
||||
__device__ constexpr auto operator()(WeiDesc) const
|
||||
{
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
return make_ConstantMergedTensorDescriptor(
|
||||
WeiDesc::Unfold(I2, I3), Sequence<1, 2>{}, Sequence<0>{});
|
||||
}
|
||||
};
|
||||
|
||||
// define B = merge(N0, Ho, Wo)
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
typename Float,
|
||||
typename InGlobalDesc,
|
||||
typename WeiGlobalDesc,
|
||||
typename OutGlobalDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads,
|
||||
class Float,
|
||||
class AccDataType,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
ConvolutionDirection ConvDirection,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
@@ -35,26 +66,42 @@ template <index_t GridSize,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
typename InBlockCopySubLengths_E_N1_B_N2,
|
||||
typename InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
typename InBlockCopyThreadClusterArrangeOrder,
|
||||
typename InBlockCopySrcAccessOrder,
|
||||
typename InBlockCopyDstAccessOrder,
|
||||
class InBlockCopySubLengths_E_N1_B_N2,
|
||||
class InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
class InBlockCopyThreadClusterArrangeOrder,
|
||||
class InBlockCopySrcAccessOrder,
|
||||
class InBlockCopyDstAccessOrder,
|
||||
index_t InBlockCopySrcDataPerRead_B,
|
||||
index_t InBlockCopyDstDataPerWrite_N2,
|
||||
typename WeiBlockCopySubLengths_E_K,
|
||||
typename WeiBlockCopyClusterLengths_E_K,
|
||||
typename WeiBlockCopyThreadClusterArrangeOrder,
|
||||
typename WeiBlockCopySrcAccessOrder,
|
||||
typename WeiBlockCopyDstAccessOrder,
|
||||
class WeiBlockCopySubLengths_E_K,
|
||||
class WeiBlockCopyClusterLengths_E_K,
|
||||
class WeiBlockCopyThreadClusterArrangeOrder,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buffer
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer_deprecated
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
constexpr auto global_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::global>{};
|
||||
|
||||
static_assert(ConvDirection == ConvolutionDirection::Forward ||
|
||||
ConvDirection == ConvolutionDirection::BackwardWeight,
|
||||
"wrong! this kernel only support convolution forward and backward-weight");
|
||||
|
||||
// this is a mess
|
||||
// TODO: find more elegent way of specifying (or calculating) performance parameters
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
@@ -65,25 +112,16 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
0,
|
||||
"wrong!");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr index_t N = in_n_c_h_w_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);
|
||||
|
||||
constexpr auto in_n_c_hi_wi_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_n_k_ho_wo_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
|
||||
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
|
||||
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
|
||||
constexpr index_t K = out_n_k_h_w_global_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
|
||||
@@ -116,43 +154,39 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
constexpr index_t BBlockWork = B / BPerBlock;
|
||||
|
||||
constexpr auto block_work_desc =
|
||||
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
|
||||
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
|
||||
|
||||
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
|
||||
const auto block_work_multi_id =
|
||||
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
|
||||
|
||||
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
|
||||
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
|
||||
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
|
||||
|
||||
// input tensor
|
||||
// global memory
|
||||
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hi_wi_global_desc,
|
||||
make_tuple(
|
||||
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
|
||||
// tensor descriptor in device memory [N0, N1, N2, Ho, Wo]
|
||||
constexpr auto in_n0_n1_n2_h_w_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Ho>{}, Number<ConvStrideH>{})
|
||||
.StridedSlice(I3, Number<Wo>{}, Number<ConvStrideW>{})
|
||||
.Fold(I0, Number<N1>{}, Number<N2>{})
|
||||
.Extract(Sequence<0, 1, 2, 4, 5>{});
|
||||
|
||||
constexpr auto in_n0_n1_n2_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
|
||||
in_n_c_hip_wip_global_desc,
|
||||
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
|
||||
PassThrough<C>{},
|
||||
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
|
||||
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}, Sequence<4, 5>{}, Sequence<6, 7>{}));
|
||||
// batch descritpor for device memory
|
||||
constexpr auto in_c_y_x_global_desc =
|
||||
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
|
||||
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
|
||||
.Extract(Sequence<1, 2, 3>{});
|
||||
|
||||
constexpr auto in_e_n1_b_n2_global_desc = transform_tensor_descriptor(
|
||||
in_n0_n1_n2_c_y_ho_x_wo_global_desc,
|
||||
make_tuple(Merge<Sequence<C, Y, X>>{},
|
||||
PassThrough<N1>{},
|
||||
Merge<Sequence<N0, Ho, Wo>>{},
|
||||
PassThrough<N2>{}),
|
||||
make_tuple(Sequence<3, 4, 6>{}, Sequence<1>{}, Sequence<0, 5, 7>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
// merged tensor descriptor in device memory [E, N1, B, N2], src of blockwise copy
|
||||
constexpr auto in_e_n1_b_n2_global_merged_desc = make_ConstantMergedTensorDescriptor(
|
||||
in_c_y_x_global_desc.Embed(in_n0_n1_n2_h_w_global_desc),
|
||||
Sequence<0, 1, 2>{},
|
||||
Sequence<4>{},
|
||||
Sequence<3, 6, 7>{},
|
||||
Sequence<5>{});
|
||||
|
||||
// memory layout descriptor in LDS [E, N1, B, N2], dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_native_tensor_descriptor_aligned(
|
||||
constexpr auto in_e_n1_b_n2_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, N1, BPerBlock, N2>{}, Number<InBlockCopyDstDataPerWrite_N2>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
@@ -164,56 +198,51 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
|
||||
// input blockwise copy
|
||||
// slice a merged tensor, reorder and copy to a normal tensor
|
||||
// this copy operator already has blockwise offset built-in
|
||||
auto blockwise_in_copy =
|
||||
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
2,
|
||||
3,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>(
|
||||
{0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v2_deprecated<
|
||||
BlockSize,
|
||||
decltype(in_e_n1_b_n2_global_merged_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc),
|
||||
decltype(in_e_n1_b_n2_block_desc.GetLengths()),
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
2,
|
||||
3,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2>({0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0});
|
||||
|
||||
// weight tensor
|
||||
// tensor descriptor in device memory, src of blockwise copy
|
||||
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
|
||||
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
|
||||
// Tensor descriptor in device memory, src of blockwise copy
|
||||
// It is constructed differently, depending on whether forward or backward weight
|
||||
// convolution
|
||||
constexpr auto wei_e_k_global_desc =
|
||||
make_wei_e_k_global_desc_v4r1_deprecated<ConvDirection>{}(wei_k_c_y_x_global_desc);
|
||||
|
||||
// tensor descriptor in LDS, dst of blockwise copy
|
||||
// be careful of LDS alignment
|
||||
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
|
||||
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<EPerBlock, KPerBlock>{},
|
||||
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});
|
||||
|
||||
// this check is ad-hoc
|
||||
// TODO: need to properly implement tensor descriptor with multiple alignment
|
||||
// requirements
|
||||
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
|
||||
"GemmDataPerReadA alignment requirement is not satisfied");

// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});

// GEMM definition
@@ -224,11 +253,8 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
// register
constexpr auto a_e_k_block_mtx_desc = make_ConstantMatrixDescriptor(wei_e_k_block_desc);

constexpr auto b_e_n1bn2_block_mtx_desc = make_ConstantMatrixDescriptor(
in_e_n1_b_n2_block_desc.GetLength(I0),
in_e_n1_b_n2_block_desc.GetLength(I1) * in_e_n1_b_n2_block_desc.GetLength(I2) *
in_e_n1_b_n2_block_desc.GetLength(I3),
in_e_n1_b_n2_block_desc.GetStride(I0));
constexpr auto b_e_n1bn2_block_mtx_desc =
make_ConstantMatrixDescriptor(in_e_n1_b_n2_block_desc.Unfold(I1, I3));

// sanity check
static_assert(KPerBlock % (GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster) ==
@@ -240,14 +266,14 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf

// c_thread_mtx definition: this is a mess
// TODO: more elegant way of defining c_thread_mtx
constexpr auto c_k0k2_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<N1 * N2>{});
constexpr auto c_k0k1_n1n2_thread_mtx_desc = make_ConstantMatrixDescriptor_packed(
Number<GemmMRepeat * GemmMPerThreadSubC>{}, Number<GemmNRepeat * GemmNPerThreadSubC>{});

const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2<
BlockSize,
decltype(a_e_k_block_mtx_desc),
decltype(b_e_n1bn2_block_mtx_desc),
decltype(c_k0k2_n1n2_thread_mtx_desc),
decltype(c_k0k1_n1n2_thread_mtx_desc),
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
@@ -274,17 +300,17 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__shared__ Float p_wei_block_double[2 * wei_block_space];

// register allocation for output
Float p_out_thread[c_k0k2_n1n2_thread_mtx_desc.GetElementSpace()];
AccDataType p_out_thread[c_k0k1_n1n2_thread_mtx_desc.GetElementSpace()];

// zero out threadwise output
threadwise_matrix_set_zero(c_k0k2_n1n2_thread_mtx_desc, p_out_thread);
threadwise_matrix_set_zero(c_k0k1_n1n2_thread_mtx_desc, p_out_thread);

// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
p_in_global, p_in_block_double);
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
p_wei_global, p_wei_block_double);
blockwise_in_copy.Run(
p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
}
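// Double-buffering scheme, as read from the preload above and the loop that follows: LDS holds
// two copies of the weight/input tiles; while the GEMM consumes the tile already resident in
// one half, the next E-slice is fetched from global memory into per-thread registers and then
// stored into the other half, so global-memory traffic overlaps with the math.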

// LDS double buffer: main body
@@ -315,12 +341,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads();

// LDS double buffer: load next data from device mem
blockwise_in_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -343,10 +367,10 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads();

// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
@@ -369,38 +393,24 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t K0 = K / K1;

// define tensor descriptor for threadwise copy
// output memory layout descriptor in register, src of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_thread_desc = make_native_tensor_descriptor_packed(
constexpr auto out_k0_k1_n1_b_n2_thread_mem_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, N1, 1, N2>{});

// output memory layout descriptor in device memory
constexpr auto out_n0_n1_n2_k0_k1_ho_wo_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(UnMerge<Sequence<N0, N1, N2>>{},
UnMerge<Sequence<K0, K1>>{},
PassThrough<Ho>{},
PassThrough<Wo>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0, 1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}, Sequence<6>{}));
constexpr auto out_n0_n1_n2_k0_k1_h_w_global_mem_desc =
out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}).Fold(I0, Number<N1>{}, Number<N2>{});

// output merged global tensor descriptor, dst of threadwise copy
constexpr auto out_k0_k1_n1_b_n2_global_desc = transform_tensor_descriptor(
out_n0_n1_n2_k0_k1_ho_wo_global_desc,
make_tuple(PassThrough<K0>{},
PassThrough<K1>{},
PassThrough<N1>{},
Merge<Sequence<N0, Ho, Wo>>{},
PassThrough<N2>{}),
make_tuple(Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{}),
make_tuple(
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
constexpr auto out_k0_k1_n1_b_n2_global_merged_desc =
make_ConstantMergedTensorDescriptor(out_n0_n1_n2_k0_k1_h_w_global_mem_desc,
Sequence<3>{},
Sequence<4>{},
Sequence<1>{},
Sequence<0, 5, 6>{},
Sequence<2>{});
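// The two variants above (the transform-based descriptor and the deprecated merged descriptor)
// appear to build the same view: the per-thread result laid out as [K0, K1, N1, B, N2] in
// registers is written back into the NKHW output, with the merged B dimension running over
// (N0, Ho, Wo); only the descriptor API differs between the two paths of this diff.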

// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
@@ -413,31 +423,25 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_padded_lds_double_buf
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col / N2;

ThreadwiseGenericTensorSliceCopy_v4r2<decltype(out_k0_k1_n1_b_n2_thread_desc),
decltype(out_k0_k1_n1_b_n2_global_desc),
decltype(
out_k0_k1_n1_b_n2_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
#if 1
.template Run<Float, Float, address_space_t::generic, address_space_t::global>
#else // tweaking
.template Run_optimized_dst_address_calculation<Float,
Float,
address_space_t::generic,
address_space_t::global>
#endif
(p_out_thread, p_out_global);
ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_n1_b_n2_thread_mem_desc),
decltype(out_k0_k1_n1_b_n2_global_merged_desc),
decltype(out_k0_k1_n1_b_n2_thread_mem_desc.GetLengths()),
arithmetic_sequence_gen<0, 5, 1>::type,
arithmetic_sequence_gen<0, 5, 1>::type,
3,
3,
1,
1>({0, 0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
0,
b_thread_data_on_global,
0})
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
}
}
};

} // namespace ck
#endif
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_DEPRECATED_HPP
@@ -2,8 +2,8 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R2_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
@@ -166,7 +166,7 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize,
Float,
decltype(in_e_n0_ho0_wo0_b_n2_ho2_wo2_global_merged_desc),
@@ -196,18 +196,18 @@ struct GridwiseConvolutionImplicitGemm_v4r2_nchw_kcyx_nkhw_lds_double_buffer
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});

// GEMM definition

@@ -2,8 +2,8 @@
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R3_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
@@ -165,7 +165,7 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1<
auto blockwise_in_copy = BlockwiseGenericTensorSliceCopy_v1_deprecated<
BlockSize,
Float,
decltype(in_e_n1_ho1_wo1_b_n2_ho2_wo2_global_merged_desc),
@@ -195,18 +195,18 @@ struct GridwiseConvolutionImplicitGemm_v4r3_nchw_kcyx_nkhw_lds_double_buffer
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v1<BlockSize,
Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
BlockwiseGenericTensorSliceCopy_v1_deprecated<BlockSize,
Float,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});

#if 0

@@ -1,25 +1,27 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantMergedTensorDescriptor.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
@@ -32,17 +34,17 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
typename InBlockCopySubLengths_E_B,
typename InBlockCopyClusterLengths_E_B,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
@@ -56,23 +58,32 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};

constexpr auto True = integral_constant<bool, true>{};

constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};
constexpr auto generic_address_space =
integral_constant<AddressSpace, AddressSpace::generic>{};
constexpr auto global_address_space =
integral_constant<AddressSpace, AddressSpace::global>{};

constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];
constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_ho_wo_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());

constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];
constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);

constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];
constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);

constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);

constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
@@ -90,50 +101,52 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
"be violated");

// divide block work by [K, B]
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % (2 * EPerBlock) == 0,
static_assert(K % KPerBlock == 0 && B % BPerBlock == 0 && E % EPerBlock == 0,
"wrong! cannot divide work evenly among block");

constexpr index_t KBlockWork = K / KPerBlock;
constexpr index_t BBlockWork = B / BPerBlock;

constexpr auto block_work_desc =
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});

const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());
const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;
const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
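// Work decomposition, as read from the code above: the output is tiled into
// KBlockWork x BBlockWork tiles of KPerBlock x BPerBlock each; the 1-d workgroup id is decoded
// into a (k, b) tile index, and the two values above are that tile's origin in the global K
// and B dimensions.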

// input tensor
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));

// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});
constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));

// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc =
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{});
constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
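// Implicit-GEMM view of the convolution: after padding and the Embed transforms, the input is
// exposed as a 2-d [E, B] matrix with E = merge(C, Y, X) (the GEMM reduction dimension) and
// B = merge(N, Ho, Wo), matching the "B = merge(N, Ho, Wo)" note at the top of this file.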

// memory layout descriptor in LDS [E, B], dst of blockwise copy
// LDS mem
// be careful of LDS alignment
constexpr auto in_e_b_block_desc =
make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});

// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
@@ -149,13 +162,13 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
{0, b_block_data_on_global}, {0, 0});

// weight tensor
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});
// global mem
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});

// tensor descriptor in LDS, dst of blockwise copy
// LDS
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

@@ -165,11 +178,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");

// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
// weight blockwise copy
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v2<BlockSize,
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
@@ -247,14 +258,12 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);

const Float* p_wei_block_on_global = p_wei_global;

// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, address_space_t::global>(p_in_global,
p_in_block_double);
blockwise_wei_copy.template Run<Float, address_space_t::global>(p_wei_global,
p_wei_block_double);
blockwise_in_copy.Run(
p_in_global, p_in_block_double, global_address_space, generic_address_space);
blockwise_wei_copy.Run(
p_wei_global, p_wei_block_double, global_address_space, generic_address_space);
}

// LDS double buffer: main body
@@ -285,10 +294,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
__syncthreads();

// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -301,50 +310,51 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer

// LDS double buffer: tail
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];
constexpr bool has_two_iteration_left = (E % (2 * EPerBlock) == 0);

// even iteration
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
if(has_two_iteration_left) // if has 2 iteration left
{
Float p_in_thread_buffer[blockwise_in_copy.GetThreadBufferSize()];
Float p_wei_thread_buffer[blockwise_wei_copy.GetThreadBufferSize()];

__syncthreads();
blockwise_in_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);
blockwise_wei_copy.MoveSrcSliceWindow(Sequence<EPerBlock, 0>{}, True);

// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
__syncthreads();

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
// LDS double buffer: load last data from device mem
blockwise_in_copy.RunLoadThreadBuffer(
p_in_global, p_in_thread_buffer, global_address_space, generic_address_space);
blockwise_wei_copy.RunLoadThreadBuffer(
p_wei_global, p_wei_thread_buffer, global_address_space, generic_address_space);

// LDS double buffer: store next data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

// odd iteration
__syncthreads();
// LDS double buffer: store last data to LDS
blockwise_in_copy.RunStoreThreadBuffer(p_in_thread_buffer,
p_in_block_double + in_block_space);
blockwise_wei_copy.RunStoreThreadBuffer(p_wei_thread_buffer,
p_wei_block_double + wei_block_space);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
__syncthreads();

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_double + wei_block_space,
p_in_block_double + in_block_space,
p_out_thread);
}
else // if has 1 iteration left
{
__syncthreads();

// LDS double buffer: GEMM on last data
blockwise_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);
}
}
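// Tail handling, as the branches above suggest: when E % (2 * EPerBlock) == 0 the main loop
// leaves two E-slices unprocessed, so one more global load/LDS store is issued and two GEMMs
// are run (one per LDS half); otherwise a single slice remains and only the final GEMM on the
// already-resident tile is needed.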

// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;

// define tensor descriptor for threadwise copy
// output global descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});

// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
@@ -356,47 +366,48 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;

// This is a hack, because slicing a merged dimension is not supported yet.
// This should be replaced with the logic above, once support for slicing a merged
// dimension becomes available.
// dst descriptor
constexpr auto out_k0_k1_b_global_desc =
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<0, 3, 4>{});
// src descriptor
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});

// src descriptor
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});
// dst descriptor
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;

using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;
constexpr index_t K0 = K / K1;
constexpr index_t B0 = B / B1;

auto threadwise_out_copy =
ThreadwiseGenericTensorSliceCopy_v2r1<decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>(
{0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global});
constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy
.template Run<Float, address_space_t::generic, address_space_t::global>(
p_out_thread, p_out_global);
constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
out_k_b_global_desc,
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));
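// In this path K is split as K0 x K1 and B as B0 x B1, where K1 and B1 are the per-block GEMM
// tile extents (per-thread sub-tile size times the level-0 and level-1 cluster sizes); a
// thread's flat (k, b) origin is then decoded as (k / K1, k % K1, b / B1, b % B1) below.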

threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
}
// output threadwise copy
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
3,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
#if 1
.Run(p_out_thread, p_out_global, generic_address_space, global_address_space);
#else // tweaking
.Run_optimized_dst_address_calculation(
p_out_thread, p_out_global, generic_address_space, global_address_space);
#endif
}
}
};

@@ -1,27 +1,25 @@
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_PADDED_LDS_DOUBLE_BUFFER_HPP
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATRD_HPP
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R4_NCHW_KCYX_NKHW_HPP_LDS_DOUBLE_BUFFER_DEPRECATRD_HPP

#include "common_header.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
#include "ConstantMatrixDescriptor.hpp"
#include "blockwise_generic_tensor_slice_copy.hpp"
#include "threadwise_generic_tensor_slice_copy.hpp"
#include "blockwise_generic_tensor_slice_copy_deprecated.hpp"
#include "blockwise_gemm.hpp"
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"

namespace ck {

// B = merge(N, Ho, Wo)
template <index_t GridSize,
index_t BlockSize,
typename Float,
typename InGlobalDesc,
typename WeiGlobalDesc,
typename OutGlobalDesc,
typename ConvStrides,
typename ConvDilations,
typename LeftPads,
typename RightPads,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
class ConvStrides,
class ConvDilations,
index_t BPerBlock,
index_t KPerBlock,
index_t EPerBlock,
@@ -34,21 +32,21 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
typename InBlockCopySubLengths_E_B,
typename InBlockCopyClusterLengths_E_B,
typename InBlockCopyThreadClusterArrangeOrder,
typename InBlockCopySrcAccessOrder,
typename InBlockCopyDstAccessOrder,
class InBlockCopySubLengths_E_B,
class InBlockCopyClusterLengths_E_B,
class InBlockCopyThreadClusterArrangeOrder,
class InBlockCopySrcAccessOrder,
class InBlockCopyDstAccessOrder,
index_t InBlockCopyDataPerAccess_B,
typename WeiBlockCopySubLengths_E_K,
typename WeiBlockCopyClusterLengths_E_K,
typename WeiBlockCopyThreadClusterArrangeOrder,
typename WeiBlockCopySrcAccessOrder,
typename WeiBlockCopyDstAccessOrder,
class WeiBlockCopySubLengths_E_K,
class WeiBlockCopyClusterLengths_E_K,
class WeiBlockCopyThreadClusterArrangeOrder,
class WeiBlockCopySrcAccessOrder,
class WeiBlockCopyDstAccessOrder,
index_t WeiBlockCopySrcDataPerRead_E,
index_t WeiBlockCopyDstDataPerWrite_K,
index_t OutThreadCopyDataPerAccess_B>
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buffer
struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_lds_double_buffer_deprecated
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
@@ -58,27 +56,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};
constexpr auto I5 = Number<5>{};

constexpr auto True = integral_constant<bool, true>{};

constexpr auto in_n_c_hi_wi_global_desc =
make_native_tensor_descriptor(InGlobalDesc::GetLengths(), InGlobalDesc::GetStrides());
constexpr auto wei_k_c_y_x_global_desc =
make_native_tensor_descriptor(WeiGlobalDesc::GetLengths(), WeiGlobalDesc::GetStrides());
constexpr auto out_n_k_ho_wo_global_desc =
make_native_tensor_descriptor(OutGlobalDesc::GetLengths(), OutGlobalDesc::GetStrides());
constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_k_c_y_x_global_desc = WeiGlobalDesc{};
constexpr auto out_n_k_h_w_global_desc = OutGlobalDesc{};

constexpr index_t N = in_n_c_hi_wi_global_desc.GetLength(I0);
constexpr index_t C = in_n_c_hi_wi_global_desc.GetLength(I1);
constexpr index_t Hi = in_n_c_hi_wi_global_desc.GetLength(I2);
constexpr index_t Wi = in_n_c_hi_wi_global_desc.GetLength(I3);
constexpr index_t N = in_n_c_h_w_global_desc.GetLengths()[0];
constexpr index_t C = in_n_c_h_w_global_desc.GetLengths()[1];

constexpr index_t K = out_n_k_ho_wo_global_desc.GetLength(I1);
constexpr index_t Ho = out_n_k_ho_wo_global_desc.GetLength(I2);
constexpr index_t Wo = out_n_k_ho_wo_global_desc.GetLength(I3);
constexpr index_t K = out_n_k_h_w_global_desc.GetLengths()[1];
constexpr index_t Ho = out_n_k_h_w_global_desc.GetLengths()[2];
constexpr index_t Wo = out_n_k_h_w_global_desc.GetLengths()[3];

constexpr index_t Y = wei_k_c_y_x_global_desc.GetLength(I2);
constexpr index_t X = wei_k_c_y_x_global_desc.GetLength(I3);
constexpr index_t Y = wei_k_c_y_x_global_desc.GetLengths()[2];
constexpr index_t X = wei_k_c_y_x_global_desc.GetLengths()[3];

constexpr index_t ConvStrideH = ConvStrides{}[0];
constexpr index_t ConvStrideW = ConvStrides{}[1];
@@ -103,67 +97,65 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
constexpr index_t BBlockWork = B / BPerBlock;

constexpr auto block_work_desc =
make_cluster_descriptor(Sequence<KBlockWork, BBlockWork>{});
make_ConstantTensorDescriptor_packed(Sequence<KBlockWork, BBlockWork>{});

const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
const auto block_work_multi_id =
block_work_desc.GetMultiIndexFrom1dIndex(get_block_1d_id());

const index_t k_block_data_on_global = block_work_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_id[1] * BPerBlock;
const index_t k_block_data_on_global = block_work_multi_id[0] * KPerBlock;
const index_t b_block_data_on_global = block_work_multi_id[1] * BPerBlock;

// input tensor
// global mem
constexpr auto in_n_c_hip_wip_global_desc = transform_tensor_descriptor(
in_n_c_hi_wi_global_desc,
make_tuple(
PassThrough<N>{}, PassThrough<C>{}, Pad<Sequence<Hi, Wi>, LeftPads, RightPads>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}));
// tensor descriptor in device memory [N, Ho, Wo]
constexpr auto in_n_ho_wo_global_desc =
in_n_c_h_w_global_desc.Extract(I0, I2, I3)
.StridedSlice(I1, Number<Ho>{}, Number<ConvStrideH>{})
.StridedSlice(I2, Number<Wo>{}, Number<ConvStrideW>{});

constexpr auto in_n_c_y_ho_x_wo_global_desc = transform_tensor_descriptor(
in_n_c_hip_wip_global_desc,
make_tuple(PassThrough<N>{},
PassThrough<C>{},
Embed<Sequence<Y, Ho>, Sequence<ConvDilationH, ConvStrideH, 0>>{},
Embed<Sequence<X, Wo>, Sequence<ConvDilationW, ConvStrideW, 0>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{}));
// batch descriptor for device memory
constexpr auto in_c_y_x_global_desc =
in_n_c_h_w_global_desc.StridedSlice(I2, Number<Y>{}, Number<ConvDilationH>{})
.StridedSlice(I3, Number<X>{}, Number<ConvDilationW>{})
.Extract(Sequence<1, 2, 3>{});

constexpr auto in_e_b_global_desc = transform_tensor_descriptor(
in_n_c_y_ho_x_wo_global_desc,
make_tuple(Merge<Sequence<C, Y, X>>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));
// merged tensor descriptor in device memory [E, B], src of blockwise copy
constexpr auto in_e_b_global_desc =
make_ConstantMergedTensorDescriptor(in_c_y_x_global_desc.Embed(in_n_ho_wo_global_desc),
Sequence<0, 1, 2>{},
Sequence<3, 4, 5>{});

// LDS mem
// memory layout descriptor in LDS [E, B], dst of blockwise copy
// be careful of LDS alignment
constexpr auto in_e_b_block_desc =
make_native_tensor_descriptor_packed(Sequence<EPerBlock, BPerBlock>{});
make_ConstantTensorDescriptor_packed(Sequence<EPerBlock, BPerBlock>{});

// input blockwise copy
// slice a merged tensor, reorder and copy to a normal tensor
// this copy operator already has blockwise offset built-in
auto blockwise_in_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(in_e_b_global_desc),
decltype(in_e_b_block_desc),
decltype(in_e_b_block_desc.GetLengths()),
InBlockCopySubLengths_E_B,
InBlockCopyClusterLengths_E_B,
InBlockCopyThreadClusterArrangeOrder,
InBlockCopySrcAccessOrder,
InBlockCopyDstAccessOrder,
1,
1,
InBlockCopyDataPerAccess_B,
InBlockCopyDataPerAccess_B>(
{0, b_block_data_on_global}, {0, 0});

// weight tensor
// global mem
constexpr auto wei_e_k_global_desc = reorder_tensor_descriptor_given_upper2lower(
unfold_tensor_descriptor(wei_k_c_y_x_global_desc, I1, I3), Sequence<1, 0>{});
// tensor descriptor in device memory, src of blockwise copy
constexpr auto wei_e_k_global_desc =
wei_k_c_y_x_global_desc.Unfold(I1, I3).ReorderGivenNew2Old(Sequence<1, 0>{});

// LDS
// tensor descriptor in LDS, dst of blockwise copy
// be careful of LDS alignment
constexpr auto wei_e_k_block_desc = make_native_tensor_descriptor_aligned(
constexpr auto wei_e_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<EPerBlock, KPerBlock>{},
Number<math::lcm(WeiBlockCopyDstDataPerWrite_K, GemmDataPerReadA)>{});

@@ -173,21 +165,23 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
static_assert(wei_e_k_block_desc.GetStride(I0) % GemmDataPerReadA == 0,
"GemmDataPerReadA alignment requirement is not satisfied");

// weight blockwise copy
// operator for blockwise copy of weight into LDS
// slice a tensor, and copy it into another tensor
// this copy operator already has blockwise offset built-in
auto blockwise_wei_copy =
BlockwiseGenericTensorSliceCopy_v4<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
BlockwiseGenericTensorSliceCopy_v2_deprecated<BlockSize,
decltype(wei_e_k_global_desc),
decltype(wei_e_k_block_desc),
decltype(wei_e_k_block_desc.GetLengths()),
WeiBlockCopySubLengths_E_K,
WeiBlockCopyClusterLengths_E_K,
WeiBlockCopyThreadClusterArrangeOrder,
WeiBlockCopySrcAccessOrder,
WeiBlockCopyDstAccessOrder,
0,
1,
WeiBlockCopySrcDataPerRead_E,
WeiBlockCopyDstDataPerWrite_K>(
{0, k_block_data_on_global}, {0, 0});

// GEMM definition
@@ -253,12 +247,14 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
// zero out threadwise output
threadwise_matrix_set_zero(c_k0k1_b0b1_thread_mtx_desc, p_out_thread);

const Float* p_wei_block_on_global = p_wei_global;

// LDS double buffer: preload data into LDS
{
blockwise_in_copy.template Run<Float, Float, address_space_t::global>(
p_in_global, p_in_block_double);
blockwise_wei_copy.template Run<Float, Float, address_space_t::global>(
p_wei_global, p_wei_block_double);
blockwise_in_copy.template Run<Float, AddressSpace::global>(p_in_global,
p_in_block_double);
blockwise_wei_copy.template Run<Float, AddressSpace::global>(p_wei_global,
p_wei_block_double);
}

// LDS double buffer: main body
@@ -289,12 +285,10 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads();

// LDS double buffer: load next data from device mem
blockwise_in_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy
.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
p_wei_global, p_wei_thread_buffer);
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
p_wei_global, p_wei_thread_buffer);

// LDS double buffer: GEMM on current data
blockwise_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);
@@ -317,9 +311,9 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
__syncthreads();

// LDS double buffer: load next data from device mem
blockwise_in_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
blockwise_in_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
p_in_global, p_in_thread_buffer);
blockwise_wei_copy.template RunLoadThreadBuffer<Float, Float, address_space_t::global>(
blockwise_wei_copy.template RunLoadThreadBuffer<Float, AddressSpace::global>(
p_wei_global, p_wei_thread_buffer);

// LDS double buffer: GEMM on current data
@@ -342,6 +336,15 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf

// copy output: register to global memory
{
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;

// define tensor descriptor for threadwise copy
// output global descriptor, for calculating origin of thread tensor
// in global memory
constexpr auto out_k_b_global_desc = make_ConstantMergedTensorDescriptor(
out_n_k_h_w_global_desc, Sequence<1>{}, Sequence<0, 2, 3>{});

// calculate origin of thread output tensor on global memory
// blockwise GEMM c matrix starting index
const auto c_thread_mtx_on_block =
@@ -353,51 +356,46 @@ struct GridwiseConvolutionImplicitGemm_v4r4_nchw_kcyx_nkhw_padded_lds_double_buf
const index_t b_thread_data_on_global =
b_block_data_on_global + c_thread_mtx_on_block.col;

// src descriptor
constexpr auto out_k0_k1_b0_b1_thread_desc = make_native_tensor_descriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat, GemmNPerThreadSubC>{});
// This is a hack, because slicing a merged dimension is not supported yet.
// This should be replaced with the logic above, once support for slicing a merged
// dimension becomes available.
// dst descriptor
constexpr auto out_k0_k1_b_global_desc =
make_ConstantMergedTensorDescriptor(out_n_k_h_w_global_desc.Fold(I1, Number<K1>{}),
Sequence<1>{},
Sequence<2>{},
Sequence<0, 3, 4>{});

// dst descriptor
constexpr index_t K1 = GemmMPerThreadSubC * GemmMLevel0Cluster * GemmMLevel1Cluster;
constexpr index_t B1 = GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster;
// src descriptor
constexpr auto out_k0_k1_b_thread_desc = make_ConstantTensorDescriptor_packed(
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNRepeat * GemmNPerThreadSubC>{});

constexpr index_t K0 = K / K1;
constexpr index_t B0 = B / B1;
using OutThreadCopySliceLengths =
Sequence<GemmMRepeat, GemmMPerThreadSubC, GemmNPerThreadSubC>;

constexpr auto out_k_b_global_desc = transform_tensor_descriptor(
out_n_k_ho_wo_global_desc,
make_tuple(PassThrough<K>{}, Merge<Sequence<N, Ho, Wo>>{}),
make_tuple(Sequence<1>{}, Sequence<0, 2, 3>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}));

constexpr auto out_k0_k1_b0_b1_global_desc = transform_tensor_descriptor(
out_k_b_global_desc,
make_tuple(UnMerge<Sequence<K0, K1>>{}, UnMerge<Sequence<B0, B1>>{}),
make_tuple(Sequence<0>{}, Sequence<1>{}),
make_tuple(Sequence<0, 1>{}, Sequence<2, 3>{}));

// output threadwise copy
ThreadwiseGenericTensorSliceCopy_v4r2<
decltype(out_k0_k1_b0_b1_thread_desc),
decltype(out_k0_k1_b0_b1_global_desc),
decltype(out_k0_k1_b0_b1_thread_desc.GetLengths()),
arithmetic_sequence_gen<0, 4, 1>::type,
3,
auto threadwise_out_copy = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<
decltype(out_k0_k1_b_thread_desc),
decltype(out_k0_k1_b_global_desc),
OutThreadCopySliceLengths,
arithmetic_sequence_gen<0, 3, 1>::type,
arithmetic_sequence_gen<0, 3, 1>::type,
2,
2,
OutThreadCopyDataPerAccess_B,
OutThreadCopyDataPerAccess_B>({0, 0, 0, 0},
OutThreadCopyDataPerAccess_B>({0, 0, 0},
{k_thread_data_on_global / K1,
k_thread_data_on_global % K1,
b_thread_data_on_global / B1,
b_thread_data_on_global % B1})
#if 1
.template Run<Float, Float, address_space_t::generic, address_space_t::global>
#else // tweaking
.template Run_optimized_dst_address_calculation<Float,
Float,
address_space_t::generic,
address_space_t::global>
#endif
(p_out_thread, p_out_global);
b_thread_data_on_global});

for(index_t nrepeat = 0; nrepeat < GemmNRepeat; ++nrepeat)
{
threadwise_out_copy
.template Run<Float, AddressSpace::generic, AddressSpace::global>(p_out_thread,
p_out_global);

threadwise_out_copy.MoveSrcSliceWindow(Sequence<0, 0, GemmNPerThreadSubC>{}, True);
threadwise_out_copy.MoveDstSliceWindow(Sequence<0, 0, B1>{}, True);
}
}
}
};

@@ -2,7 +2,7 @@
#define CK_CONSTANT_MATRIX_DESCRIPTOR_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"
#include "tensor_descriptor.hpp"

namespace ck {

@@ -32,6 +32,11 @@ struct ConstantMatrixDescriptor
return irow * RowStride_ + icol;
}

__host__ __device__ static index_t CalculateOffset(index_t irow, index_t icol)
{
return GetOffsetFromMultiIndex(irow, icol);
}
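// CalculateOffset simply forwards to GetOffsetFromMultiIndex; presumably it was added so that
// ConstantMatrixDescriptor exposes the same offset-calculation name as the newer
// tensor-descriptor interface introduced by this refactor.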

template <index_t SubNRow, index_t SubNCol>
__host__ __device__ static constexpr auto MakeSubMatrixDescriptor(Number<SubNRow>,
Number<SubNCol>)
@@ -54,9 +59,10 @@ __host__ __device__ constexpr auto
}

template <typename... Ts>
__host__ __device__ constexpr auto make_ConstantMatrixDescriptor(ConstantTensorDescriptor<Ts...>)
__host__ __device__ constexpr auto
make_ConstantMatrixDescriptor(ConstantTensorDescriptor_deprecated<Ts...>)
{
using TDesc = ConstantTensorDescriptor<Ts...>;
using TDesc = ConstantTensorDescriptor_deprecated<Ts...>;
static_assert(TDesc::GetNumOfDimension() == 2, "wrong");
static_assert(TDesc::GetStrides()[1] == 1, "wrong");
return ConstantMatrixDescriptor<TDesc::GetLengths()[0],

@@ -1,26 +1,26 @@
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_HPP
#ifndef CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_MERGED_TENSOR_DESCRIPTOR_DEPRECATED_HPP

#include "common_header.hpp"
#include "ConstantTensorDescriptor.hpp"
#include "ConstantTensorDescriptor_deprecated.hpp"

namespace ck {

// OriginalTensorDesc : ConstantTensorDescriptor<...>
// OriginalTensorDesc : ConstantTensorDescriptor_deprecated<...>
// it's the tensor whose dimensions are to be merged
// OriginalDimMergeSeqs : Sequence<...>...
// each is a sequence of original dimensions (of OriginalTensorDesc) to be merged
template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
struct ConstantMergedTensorDescriptor
struct ConstantMergedTensorDescriptor_deprecated
{
using Type = ConstantMergedTensorDescriptor;
using Type = ConstantMergedTensorDescriptor_deprecated;

static constexpr auto mOriginalDimMergeSeqs = std::tuple<OriginalDimMergeSeqs...>{};

static constexpr index_t nDim = sizeof...(OriginalDimMergeSeqs);
static constexpr index_t nOriginalDim = OriginalTensorDesc::GetNumOfDimension();

__host__ __device__ constexpr ConstantMergedTensorDescriptor()
__host__ __device__ constexpr ConstantMergedTensorDescriptor_deprecated()
{
static_assert(nDim <= nOriginalDim, "wrong!");

@@ -189,7 +189,7 @@ struct ConstantMergedTensorDescriptor
{
constexpr auto lengths = GetLengths();
constexpr auto strides = calculate_tensor_strides_packed(lengths);
return ConstantTensorDescriptor<decltype(lengths), decltype(strides)>{};
return ConstantTensorDescriptor_deprecated<decltype(lengths), decltype(strides)>{};
}
};

@@ -197,7 +197,7 @@ template <class OriginalTensorDesc, class... OriginalDimMergeSeqs>
__host__ __device__ constexpr auto make_ConstantMergedTensorDescriptor(OriginalTensorDesc,
OriginalDimMergeSeqs...)
{
return ConstantMergedTensorDescriptor<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
return ConstantMergedTensorDescriptor_deprecated<OriginalTensorDesc, OriginalDimMergeSeqs...>{};
}

template <class TDesc>
@@ -1,12 +1,12 @@
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_HPP
#define CK_CONSTANT_TENSOR_DESCRIPTOR_HPP
#ifndef CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP
#define CK_CONSTANT_TENSOR_DESCRIPTOR_DEPRECATED_HPP

#include "common_header.hpp"

namespace ck {

template <class Lengths>
__host__ __device__ constexpr auto calculate_tensor_strides_packed_old(Lengths)
__host__ __device__ constexpr auto calculate_tensor_strides_packed_deprecated(Lengths)
{
return reverse_inclusive_scan_sequence(
Lengths{}.PopFront(), math::multiplies<index_t>{}, Number<1>{})
@@ -19,18 +19,18 @@ __host__ __device__ constexpr auto calculate_tensor_strides_aligned_old(Lengths,
constexpr index_t L_back_align =
Align * math::integer_divide_ceiler<index_t>{}(Lengths{}.Back(), Align);

return calculate_tensor_strides_packed_old(
return calculate_tensor_strides_packed_deprecated(
Lengths{}.Modify(Number<Lengths{}.GetSize() - 1>{}, Number<L_back_align>{}));
}
|
||||
|
||||
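// A stand-alone sketch of the stride arithmetic above, using plain arrays instead of
// Sequence<...> (ad-hoc helper names, not part of this header). Packed strides are a
// reverse product scan of the lengths with a trailing 1; "aligned" strides first round
// the last length up to a multiple of Align. For example, lengths {8, 3, 5} give packed
// strides {15, 5, 1}; with Align = 4 the last length is padded to 8, giving {24, 8, 1}.
template <unsigned NDim>
void sketch_strides_packed(const unsigned (&lengths)[NDim], unsigned (&strides)[NDim])
{
    unsigned running = 1;
    for(unsigned i = NDim; i-- > 0;)
    {
        strides[i] = running; // stride of dim i = product of the lengths to its right
        running *= lengths[i];
    }
}

template <unsigned NDim>
void sketch_strides_aligned(unsigned (&lengths)[NDim], unsigned (&strides)[NDim], unsigned align)
{
    lengths[NDim - 1] = align * ((lengths[NDim - 1] + align - 1) / align); // pad last length
    sketch_strides_packed(lengths, strides);
}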
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
struct ConstantTensorDescriptor_deprecated
|
||||
{
|
||||
using Type = ConstantTensorDescriptor;
|
||||
using Type = ConstantTensorDescriptor_deprecated;
|
||||
|
||||
static constexpr index_t nDim = Lengths::GetSize();
|
||||
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor()
|
||||
__host__ __device__ constexpr ConstantTensorDescriptor_deprecated()
|
||||
{
|
||||
static_assert(Lengths::GetSize() == Strides::GetSize(), "nDim not consistent");
|
||||
}
|
||||
@@ -186,7 +186,7 @@ struct ConstantTensorDescriptor
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
using PackedStrides = decltype(calculate_tensor_strides_packed_old(GetLengths()));
|
||||
using PackedStrides = decltype(calculate_tensor_strides_packed_deprecated(GetLengths()));
|
||||
|
||||
// calculate index in each of the dimensions in the order of their dimension
|
||||
static_for<0, nDim - 1, 1>{}(lambda_GetMultiIndexFrom1dIndex<PackedStrides>(id, multi_id));
|
||||
@@ -284,7 +284,7 @@ struct ConstantTensorDescriptor
|
||||
using extract_lengths = decltype(Lengths::Extract(extract_dims...));
|
||||
using extract_strides = decltype(Strides::Extract(extract_dims...));
|
||||
|
||||
return ConstantTensorDescriptor<extract_lengths, extract_strides>{};
|
||||
return ConstantTensorDescriptor_deprecated<extract_lengths, extract_strides>{};
|
||||
}
|
||||
|
||||
template <index_t... IDims>
|
||||
@@ -294,13 +294,13 @@ struct ConstantTensorDescriptor
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor<Ts...>)
|
||||
__host__ __device__ static constexpr auto Embed(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
using leaf_tensor = ConstantTensorDescriptor<Ts...>;
|
||||
using leaf_tensor = ConstantTensorDescriptor_deprecated<Ts...>;
|
||||
|
||||
return ConstantTensorDescriptor<decltype(GetLengths().PushBack(leaf_tensor::GetLengths())),
|
||||
decltype(
|
||||
GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
|
||||
return ConstantTensorDescriptor_deprecated<
|
||||
decltype(GetLengths().PushBack(leaf_tensor::GetLengths())),
|
||||
decltype(GetStrides().PushBack(leaf_tensor::GetStrides()))>{};
|
||||
}
|
||||
|
||||
template <index_t IDimVector, index_t DataPerVector>
|
||||
@@ -351,7 +351,7 @@ struct ConstantTensorDescriptor
|
||||
using vectorized_strides =
|
||||
decltype((Strides{} / Number<DataPerVector>{}).Modify(Number<IDim>{}, Number<1>{}));
|
||||
|
||||
return ConstantTensorDescriptor<vectorized_lengths, vectorized_strides>{};
|
||||
return ConstantTensorDescriptor_deprecated<vectorized_lengths, vectorized_strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLen>
|
||||
@@ -359,7 +359,7 @@ struct ConstantTensorDescriptor
|
||||
{
|
||||
using slice_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLen>{}));
|
||||
|
||||
return ConstantTensorDescriptor<slice_lengths, Strides>{};
|
||||
return ConstantTensorDescriptor_deprecated<slice_lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
@@ -367,7 +367,7 @@ struct ConstantTensorDescriptor
|
||||
{
|
||||
static_assert(slice_lengths.GetSize() == nDim, "wrong!");
|
||||
|
||||
return ConstantTensorDescriptor<decltype(slice_lengths), Strides>{};
|
||||
return ConstantTensorDescriptor_deprecated<decltype(slice_lengths), Strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t SliceLength, index_t SliceStride>
|
||||
@@ -379,7 +379,7 @@ struct ConstantTensorDescriptor
|
||||
using new_lengths = decltype(Lengths::Modify(Number<IDim>{}, Number<SliceLength>{}));
|
||||
using new_strides = decltype(Strides::Modify(Number<IDim>{}, Number<new_stride>{}));
|
||||
|
||||
return ConstantTensorDescriptor<new_lengths, new_strides>{};
|
||||
return ConstantTensorDescriptor_deprecated<new_lengths, new_strides>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... FoldIntervals>
|
||||
@@ -418,7 +418,7 @@ struct ConstantTensorDescriptor
|
||||
constexpr auto new_strides =
|
||||
GetStrides().Extract(left).PushBack(fold_strides).PushBack(GetStrides().Extract(right));
|
||||
|
||||
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
|
||||
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
|
||||
}
|
||||
|
||||
template <index_t IDim, index_t... FoldIntervals>
|
||||
@@ -462,54 +462,55 @@ struct ConstantTensorDescriptor
|
||||
.PushBack(Number<unfold_stride>{})
|
||||
.PushBack(GetStrides().Extract(right));
|
||||
|
||||
return ConstantTensorDescriptor<decltype(new_lengths), decltype(new_strides)>{};
|
||||
return ConstantTensorDescriptor_deprecated<decltype(new_lengths), decltype(new_strides)>{};
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr auto Pack()
|
||||
{
|
||||
using packed_strides = decltype(calculate_tensor_strides_packed_old(Lengths{}));
|
||||
return ConstantTensorDescriptor<Lengths, packed_strides>{};
|
||||
using packed_strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, packed_strides>{};
|
||||
}
|
||||
|
||||
template <class MapNew2Old>
|
||||
__host__ __device__ static constexpr auto ReorderGivenNew2Old(MapNew2Old)
|
||||
{
|
||||
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
|
||||
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
|
||||
return ConstantTensorDescriptor_deprecated<
|
||||
decltype(Lengths::ReorderGivenNew2Old(MapNew2Old{})),
|
||||
decltype(Strides::ReorderGivenNew2Old(MapNew2Old{}))>{};
|
||||
}
|
||||
|
||||
template <class MapOld2New>
|
||||
__host__ __device__ static constexpr auto ReorderGivenOld2New(MapOld2New)
|
||||
{
|
||||
return ConstantTensorDescriptor<decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
|
||||
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
|
||||
return ConstantTensorDescriptor_deprecated<
|
||||
decltype(Lengths::ReorderGivenOld2New(MapOld2New{})),
|
||||
decltype(Strides::ReorderGivenOld2New(MapOld2New{}))>{};
|
||||
}
|
||||
};
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_packed(Lengths)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_packed_old(Lengths{}));
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
using Strides = decltype(calculate_tensor_strides_packed_deprecated(Lengths{}));
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Strides)
|
||||
{
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, index_t Align>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
using Strides = decltype(calculate_tensor_strides_aligned_old(Lengths{}, Number<Align>{}));
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
return ConstantTensorDescriptor_deprecated<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ void
|
||||
print_ConstantTensorDescriptor(const char* s,
|
||||
ConstantTensorDescriptor<Sequence<Lengths...>, Sequence<Strides...>>)
|
||||
__host__ __device__ void print_ConstantTensorDescriptor(
|
||||
const char* s, ConstantTensorDescriptor_deprecated<Sequence<Lengths...>, Sequence<Strides...>>)
|
||||
{
|
||||
constexpr index_t ndim = sizeof...(Lengths);
|
||||
|
||||
@@ -0,0 +1,173 @@
|
||||
#ifndef CK_PRINT_TENSOR_DESCRIPTOR_HPP
|
||||
#define CK_PRINT_TENSOR_DESCRIPTOR_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename... NativeDimensions>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
|
||||
{
|
||||
print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ void print_tensor_descriptor(const char* s,
|
||||
const TransformedTensorDescriptor<Ts...>& desc)
|
||||
{
|
||||
print_tensor_descriptor_impl(s, desc.GetLengths());
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
|
||||
{
|
||||
constexpr index_t nDim = sizeof...(Lengths);
|
||||
|
||||
static_assert(nDim > 0 && nDim <= 12, "wrong!");
|
||||
|
||||
static_if<nDim == 1>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 2>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 3>{}([&](auto) {
|
||||
printf(
|
||||
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 4>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 5>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 6>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 7>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
|
||||
"%u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 11>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
|
||||
"%u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 12>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
|
||||
"%u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t... Lengths>
|
||||
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
|
||||
{
|
||||
constexpr index_t nDim = sizeof...(Lengths);
|
||||
|
||||
static_assert(nDim > 0 && nDim <= 12, "wrong!");
|
||||
|
||||
static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 2>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 3>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 4>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 5>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 6>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 7>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 11>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 12>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -1,5 +1,5 @@
|
||||
#ifndef CK_TENSOR_COORDINATE_V2_HPP
|
||||
#define CK_TENSOR_COORDINATE_V2_HPP
|
||||
#ifndef CK_TENSOR_COORDINATE_HPP
|
||||
#define CK_TENSOR_COORDINATE_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "dimension.hpp"
|
||||
@@ -8,9 +8,24 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// A "tensor cooridnate" is an opaque object that represents a "point of location" inside a tensor
|
||||
// At the bare minimun, user should be able to query the following information from a tensor
|
||||
// coordinate:
|
||||
// 1. Tensor descriptor
|
||||
// 2. Location, represented in the form of multi-index
|
||||
// 3. Location, represented in the form of the offset to the origin of the tensor
|
||||
// 4. If the location is inside invalid area or not, i.e. the padding area of an implicitly padded
|
||||
// tensor is considered invalid, because the padding area doesn't have any physical memory
|
||||
// allocation
|
||||
// A tensor cooridnate also provides following functionality:
|
||||
// 1. Given step size in each dimension, update itself, or return a new tensor cooridnate, so user
|
||||
// can freely move the "point of location" inside the tensor
|
||||
|
||||
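// A minimal stand-alone sketch of the idea above for an ordinary 3-d strided tensor
// (ad-hoc names; not the NativeTensorCoordinate/TransformedTensorCoordinate classes
// declared below, which also handle merged and transformed dimensions).
struct SketchCoordinate3d
{
    unsigned lengths[3];
    unsigned strides[3];
    unsigned index[3]; // location as a multi-index
    unsigned offset;   // the same location as an offset from the tensor origin

    // move the "point of location" by a step in each dimension
    void Move(const unsigned (&step)[3])
    {
        for(unsigned i = 0; i < 3; ++i)
        {
            index[i] += step[i];
            offset += step[i] * strides[i];
        }
    }

    // the location is valid only if it falls inside the (unpadded) lengths
    bool IsValid() const
    {
        return index[0] < lengths[0] && index[1] < lengths[1] && index[2] < lengths[2];
    }
};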
// wrapper class for NativeTensorCoordinate and TransformedTensorCoordinate
|
||||
template <typename TensorDesc>
|
||||
struct TensorCoordinate;
|
||||
|
||||
// tensor coordinate for native tensor
|
||||
template <typename NativeTensorDesc>
|
||||
struct NativeTensorCoordinate
|
||||
{
|
||||
@@ -78,12 +93,10 @@ struct NativeTensorCoordinate
|
||||
return coord;
|
||||
}
|
||||
|
||||
#if 0 // tweaking
|
||||
__host__ __device__ static constexpr index_t CalculateOffsetDiff(const Index& idx_diff)
|
||||
{
|
||||
return tensor_desc_type::CalculateOffsetDiff(idx_diff);
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr bool IsUpperIndexMappedToValidOffset() { return true; }
|
||||
|
||||
@@ -96,6 +109,7 @@ struct NativeTensorCoordinate
|
||||
index_t mOffset;
|
||||
};
|
||||
|
||||
// tensor coordinate for transformed tensor
|
||||
template <typename TransformedTensorDesc>
|
||||
struct TransformedTensorCoordinate
|
||||
{
|
||||
@@ -177,10 +191,10 @@ struct TransformedTensorCoordinate
|
||||
return coord_up;
|
||||
}
|
||||
|
||||
#if 0 // tweaking
|
||||
// Calculate offset diff without updating tensor-coordinate
|
||||
// If idx_up_diff is known at compile time, and has only non-zero entries on linear dimensions,
|
||||
// then all calculation can be done at compile-time.
|
||||
// TODO: this function is not compiled to expected ISA
|
||||
__host__ __device__ constexpr index_t CalculateOffsetDiff(const UpperIndex& idx_up_diff) const
|
||||
{
|
||||
// For transformation of multi-index difference, not all transformation functions need to
|
||||
@@ -191,7 +205,6 @@ struct TransformedTensorCoordinate
|
||||
|
||||
return GetLowerCoordinate().CalculateOffsetDiff(idx_low_diff);
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ constexpr bool IsUpperIndexMappedToValidOffset() const
|
||||
{
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
#define CK_TENSOR_COORDINATE_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// TensorDesc is ConstantTensorDescriptor
|
||||
// TensorDesc is ConstantTensorDescriptor_deprecated
|
||||
template <class TensorDesc>
|
||||
struct NormalTensorCoordinate_deprecated
|
||||
{
|
||||
@@ -95,18 +95,19 @@ struct NormalTensorCoordinate_deprecated
|
||||
index_t mOffset;
|
||||
};
|
||||
|
||||
// TensorDesc is ConstantMergedTensorDescriptor
|
||||
// TensorDesc is ConstantMergedTensorDescriptor_deprecated
|
||||
template <class TensorDesc>
|
||||
struct MergedTensorCoordinate
|
||||
struct MergedTensorCoordinate_deprecated
|
||||
{
|
||||
using type = MergedTensorCoordinate;
|
||||
using type = MergedTensorCoordinate_deprecated;
|
||||
using tensor_desc_type = TensorDesc;
|
||||
|
||||
static constexpr index_t nDim = tensor_desc_type::GetNumOfDimension();
|
||||
static constexpr index_t nOriginalDim =
|
||||
tensor_desc_type::GetOriginalTensorDescriptor().GetNumOfDimension();
|
||||
|
||||
__host__ __device__ constexpr MergedTensorCoordinate(Array<index_t, nDim> tensor_index)
|
||||
__host__
|
||||
__device__ constexpr MergedTensorCoordinate_deprecated(Array<index_t, nDim> tensor_index)
|
||||
: mOriginalIndex{tensor_desc_type::GetOriginalMultiIndexFromMultiIndex(tensor_index)}
|
||||
{
|
||||
// partial offset on each dimension
|
||||
@@ -127,8 +128,8 @@ struct MergedTensorCoordinate
|
||||
}
|
||||
|
||||
template <class... Xs>
|
||||
__host__ __device__ constexpr MergedTensorCoordinate(Xs... xs)
|
||||
: MergedTensorCoordinate(Array<index_t, nDim>{xs...})
|
||||
__host__ __device__ constexpr MergedTensorCoordinate_deprecated(Xs... xs)
|
||||
: MergedTensorCoordinate_deprecated(Array<index_t, nDim>{xs...})
|
||||
{
|
||||
}
|
||||
|
||||
@@ -311,7 +312,7 @@ struct MergedTensorCoordinate
|
||||
// dimensions, and those merged dimensions, that would never be involved in index
|
||||
// arithmetic after construction of TensorCoordinate.
|
||||
// TODO: refactor TensorCoordinate, after introducing the concept of "dimensions"
|
||||
// and simplify implementation of ConstantMergedTensorDescriptor, so we don't need to
|
||||
// and simplify implementation of ConstantMergedTensorDescriptor_deprecated, so we don't need to
|
||||
// count on the compiler to optimize away that register memory for us
|
||||
Array<index_t, nOriginalDim> mOriginalIndex;
|
||||
Array<index_t, nDim> mPartialOffsets;
|
||||
@@ -326,16 +327,17 @@ struct TensorCoordinate_deprecated
|
||||
private:
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor<Ts...>>();
|
||||
return NormalTensorCoordinate_deprecated<ConstantTensorDescriptor_deprecated<Ts...>>();
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor_deprecated<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate<ConstantMergedTensorDescriptor<Ts...>>();
|
||||
return MergedTensorCoordinate_deprecated<
|
||||
ConstantMergedTensorDescriptor_deprecated<Ts...>>();
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
#ifndef CK_TENSOR_COORDINATE_HELPER_HPP
|
||||
#define CK_TENSOR_COORDINATE_HELPER_HPP
|
||||
|
||||
#include "tensor_coordiante_v2.hpp"
|
||||
#include "tensor_coordiante_hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename TensorDesc>
|
||||
__host__ __device__ constexpr auto
|
||||
make_tensor_coordinate_v2(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
|
||||
make_tensor_coordinate(TensorDesc, MultiIndex<TensorDesc::GetNumOfDimension()> idx)
|
||||
{
|
||||
return typename TensorCoordinate<TensorDesc>::type(idx);
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
// tensor descriptor for "native tensor"
|
||||
// A "native tensor" is a "true" tensor that can be represented by Lengths and Strides
|
||||
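// The defining property of a "native" tensor in a stand-alone sketch: the offset of a
// multi-index is simply its dot product with the strides (ad-hoc helper, not part of
// NativeTensorDescriptor below). For example, a packed 2x3x4 tensor has strides
// {12, 4, 1}, so index {1, 2, 3} maps to offset 1*12 + 2*4 + 3*1 = 23.
template <unsigned NDim>
unsigned sketch_native_offset(const unsigned (&idx)[NDim], const unsigned (&strides)[NDim])
{
    unsigned offset = 0;
    for(unsigned i = 0; i < NDim; ++i)
        offset += idx[i] * strides[i];
    return offset;
}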
template <typename... NativeDimensions>
|
||||
struct NativeTensorDescriptor
|
||||
{
|
||||
@@ -113,12 +115,10 @@ struct NativeTensorDescriptor
|
||||
|
||||
__host__ __device__ static constexpr auto GetNonLinearDimensions() { return Sequence<>{}; }
|
||||
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
|
||||
{
|
||||
return Tuple<>{};
|
||||
}
|
||||
#endif
|
||||
|
||||
__host__ __device__ static constexpr bool
|
||||
IsUpperIndexMappedToValidOffset(const Index& /* idx */)
|
||||
@@ -127,14 +127,11 @@ struct NativeTensorDescriptor
|
||||
}
|
||||
};
|
||||
|
||||
// LowerTensorDescriptor
|
||||
// Transforms: Tuple<DimensionTransforms...>
|
||||
// LowerDimensionIds: Tuple<Sequence<...>>
|
||||
// UpperDimensionIds: Tuple<Sequence<...>>
|
||||
template <typename LowTensorDescriptor,
|
||||
typename Transforms,
|
||||
typename LowDimensionIds,
|
||||
typename UpDimensionIds>
|
||||
// Tensor descriptor for "transformed tensor"
|
||||
template <typename LowTensorDescriptor, // NativeTensorDescriptor or TransformedTensorDescriptor
|
||||
typename Transforms, // Tuple<MultiIndexTransforms...>
|
||||
typename LowDimensionIds, // Tuple<Sequence<...>>
|
||||
typename UpDimensionIds> // Tuple<Sequence<...>>
|
||||
struct TransformedTensorDescriptor
|
||||
{
|
||||
using type = TransformedTensorDescriptor;
|
||||
@@ -412,6 +409,7 @@ struct TransformedTensorDescriptor
|
||||
{
|
||||
#if 0
|
||||
// create tuple of linear dimension masks, for all transformations
|
||||
// TODO: this doesn't compile, because transform_tuples() complains about constexpr
|
||||
constexpr auto tuple_of_linear_dimension_mask =
|
||||
transform_tuples(lambda_get_linear_dimension_mask_of_single_tranform{},
|
||||
Transforms{},
|
||||
@@ -419,7 +417,7 @@ struct TransformedTensorDescriptor
|
||||
UpDimensionIds{});
|
||||
#else
|
||||
// create tuple of linear dimension masks, for all transformations
|
||||
// TODO: this is a hack, transform_tuples() doesn't compile, complain about constexpr
|
||||
// TODO: this is a hack
|
||||
constexpr auto tuple_of_linear_dimension_mask = dummy_transform_tuples_impl(
|
||||
lambda_get_linear_dimension_mask_of_single_tranform{},
|
||||
Transforms{},
|
||||
@@ -465,7 +463,7 @@ struct TransformedTensorDescriptor
|
||||
#if 0
|
||||
__host__ __device__ static constexpr auto GetNonLinearIndependentDimensionGroups()
|
||||
{
|
||||
// not implemented
|
||||
// TODO: not implemented
|
||||
}
|
||||
#endif
|
||||
|
||||
|
||||
@@ -63,10 +63,11 @@ template <typename LowerTensorDescriptor,
|
||||
index_t... LowerLengths,
|
||||
index_t... LowerDimensionIds,
|
||||
index_t... UpperDimensionIds>
|
||||
__host__ __device__ constexpr auto reorder_tensor_descriptor_impl(LowerTensorDescriptor,
|
||||
Sequence<LowerLengths...>,
|
||||
Sequence<LowerDimensionIds...>,
|
||||
Sequence<UpperDimensionIds...>)
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_transformed_tensor_descriptor_impl(LowerTensorDescriptor,
|
||||
Sequence<LowerLengths...>,
|
||||
Sequence<LowerDimensionIds...>,
|
||||
Sequence<UpperDimensionIds...>)
|
||||
{
|
||||
return TransformedTensorDescriptor<LowerTensorDescriptor,
|
||||
Tuple<PassThrough<LowerLengths>...>,
|
||||
@@ -74,17 +75,40 @@ __host__ __device__ constexpr auto reorder_tensor_descriptor_impl(LowerTensorDes
|
||||
Tuple<Sequence<UpperDimensionIds>...>>{};
|
||||
}
|
||||
|
||||
template <typename LowerTensorDescriptor, typename MapLower2Upper>
|
||||
// reorder a NativeTensorDescriptor
|
||||
template <typename... Ts, typename MapLower2Upper>
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_tensor_descriptor_given_lower2upper(LowerTensorDescriptor, MapLower2Upper)
|
||||
reorder_tensor_descriptor_given_lower2upper(NativeTensorDescriptor<Ts...>, MapLower2Upper)
|
||||
{
|
||||
static_assert(is_valid_sequence_map<MapLower2Upper>{},
|
||||
"wrong! MapLower2Upper is not a valid map");
|
||||
|
||||
return reorder_tensor_descriptor_impl(
|
||||
LowerTensorDescriptor{},
|
||||
LowerTensorDescriptor::GetLengths(),
|
||||
typename arithmetic_sequence_gen<0, LowerTensorDescriptor::GetNumOfDimension(), 1>::type{},
|
||||
constexpr auto old_desc = NativeTensorDescriptor<Ts...>{};
|
||||
|
||||
static_assert(old_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
|
||||
|
||||
constexpr auto new_lengths = old_desc.GetLengths().ReorderGivenOld2New(MapLower2Upper{});
|
||||
constexpr auto new_strides = old_desc.GetStrides().ReorderGivenOld2New(MapLower2Upper{});
|
||||
|
||||
return make_native_tensor_descriptor(new_lengths, new_strides);
|
||||
}
|
||||
|
||||
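// A stand-alone sketch of the permutation step above (ad-hoc helper; the exact convention
// is defined by Sequence::ReorderGivenOld2New, this assumes the common "scatter" reading,
// i.e. old dimension i lands at new position map[i]).
template <unsigned NDim>
void sketch_reorder_old2new(const unsigned (&old_vals)[NDim],
                            const unsigned (&map_old2new)[NDim],
                            unsigned (&new_vals)[NDim])
{
    for(unsigned i = 0; i < NDim; ++i)
        new_vals[map_old2new[i]] = old_vals[i];
}
// e.g. lengths {N, C, H, W} with map {0, 3, 1, 2} become {N, H, W, C}: both the lengths
// and the strides are permuted the same way before the new native descriptor is made.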
// reorder a TransformedTensorDescriptor
|
||||
template <typename... Ts, typename MapLower2Upper>
|
||||
__host__ __device__ constexpr auto
|
||||
reorder_tensor_descriptor_given_lower2upper(TransformedTensorDescriptor<Ts...>, MapLower2Upper)
|
||||
{
|
||||
static_assert(is_valid_sequence_map<MapLower2Upper>{},
|
||||
"wrong! MapLower2Upper is not a valid map");
|
||||
|
||||
constexpr auto low_desc = TransformedTensorDescriptor<Ts...>{};
|
||||
|
||||
static_assert(low_desc.GetNumOfDimension() == MapLower2Upper::Size(), "wrong!");
|
||||
|
||||
return reorder_transformed_tensor_descriptor_impl(
|
||||
low_desc,
|
||||
low_desc.GetLengths(),
|
||||
typename arithmetic_sequence_gen<0, low_desc.GetNumOfDimension(), 1>::type{},
|
||||
MapLower2Upper{});
|
||||
}
|
||||
|
||||
@@ -97,7 +121,7 @@ __host__ __device__ constexpr auto
|
||||
}
|
||||
|
||||
template <typename Lengths, typename Strides>
|
||||
__host__ __device__ constexpr bool AreDimensionsUnfoldable(Lengths, Strides)
|
||||
__host__ __device__ constexpr bool are_dimensions_unfoldable(Lengths, Strides)
|
||||
{
|
||||
static_assert(Lengths::Size() == Strides::Size(), "wrong!");
|
||||
|
||||
@@ -129,7 +153,7 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript
|
||||
constexpr auto right = typename arithmetic_sequence_gen<LastUnfoldDim + 1, nDim, 1>::type{};
|
||||
|
||||
// sanity check: dimensions being unfolded must be unfoldable
|
||||
static_assert(AreDimensionsUnfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
|
||||
static_assert(are_dimensions_unfoldable(desc.GetLengths(middle), desc.GetStrides(middle)),
|
||||
"wrong! not unfoldable");
|
||||
|
||||
// unfolded length, stride
|
||||
@@ -148,30 +172,6 @@ __host__ __device__ constexpr auto unfold_tensor_descriptor(NativeTensorDescript
|
||||
return make_native_tensor_descriptor(new_lengths, new_strides);
|
||||
}
|
||||
|
||||
#if 0
|
||||
// not implemented
|
||||
template <typename LowerTensorDescriptor,
|
||||
typename PadDimensionIds,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
__host__ __device__ constexpr auto
|
||||
pad_tensor_descriptor(LowerTensorDescriptor, PadLowerDimensionIds, LeftPads, RightPads)
|
||||
{
|
||||
constexpr index_t nDim = LowerTensorDescriptor::GetNumOfDimension();
|
||||
|
||||
constexpr auto non_pad_low_dim_ids = xxx;
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
LowerTensorDescriptor{},
|
||||
make_tuple(Pad<decltype(LowerTensorDescriptor::GetLengths(PadLowerDimensionIds{})),
|
||||
LeftPads,
|
||||
RightPads>{})
|
||||
.PushBack(PassThrough<xxxx>...),
|
||||
make_tuple(PadLowerDimensionIds{}).PushBack(xxxx),
|
||||
sequence_to_tuple(typename arithmetic_sequence_gen<0, nDim, 1>::type{}));
|
||||
}
|
||||
#endif
|
||||
|
||||
// a cluster that maps a 1-d index to an N-d index
|
||||
template <typename Lengths, typename ArrangeOrder>
|
||||
struct ClusterDescriptor
|
||||
@@ -205,169 +205,7 @@ template <typename Lengths,
|
||||
__host__ __device__ constexpr auto make_cluster_descriptor(
|
||||
Lengths, ArrangeOrder order = typename arithmetic_sequence_gen<0, Lengths::Size(), 1>::type{})
|
||||
{
|
||||
return ClusterDescriptor<Lengths, ArrangeOrder>{};
|
||||
}
|
||||
|
||||
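// A stand-alone sketch of the 1-d to N-d decomposition a ClusterDescriptor (above) performs
// (ad-hoc helper; ArrangeOrder handling is omitted and row-major order is assumed, which
// need not match the member functions of ClusterDescriptor exactly).
// E.g. lengths {4, 8}: id 13 -> {1, 5}.
template <unsigned NDim>
void sketch_cluster_index(unsigned id, const unsigned (&lengths)[NDim], unsigned (&idx)[NDim])
{
    for(unsigned i = NDim; i-- > 0;)
    {
        idx[i] = id % lengths[i]; // last dimension varies fastest
        id /= lengths[i];
    }
}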
template <typename... NativeDimensions>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor(const char* s, const NativeTensorDescriptor<NativeDimensions...>& desc)
|
||||
{
|
||||
print_tensor_descriptor_impl(s, desc.GetLengths(), desc.GetStrides());
|
||||
}
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ void print_tensor_descriptor(const char* s,
|
||||
const TransformedTensorDescriptor<Ts...>& desc)
|
||||
{
|
||||
print_tensor_descriptor_impl(s, desc.GetLengths());
|
||||
}
|
||||
|
||||
template <index_t... Lengths, index_t... Strides>
|
||||
__host__ __device__ void
|
||||
print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>, Sequence<Strides...>)
|
||||
{
|
||||
constexpr index_t nDim = sizeof...(Lengths);
|
||||
|
||||
static_assert(nDim > 0 && nDim <= 12, "wrong!");
|
||||
|
||||
static_if<nDim == 1>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u}, strides {%u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 2>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u}, strides {%u %u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 3>{}([&](auto) {
|
||||
printf(
|
||||
"%s dim %u, lengths {%u %u %u}, strides {%u %u %u}\n", s, nDim, Lengths..., Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 4>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u}, strides {%u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 5>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u}, strides {%u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 6>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u}, strides {%u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 7>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u %u "
|
||||
"%u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 11>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u %u "
|
||||
"%u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
|
||||
static_if<nDim == 12>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}, strides {%u %u %u %u %u "
|
||||
"%u %u %u %u "
|
||||
"%u %u %u}\n",
|
||||
s,
|
||||
nDim,
|
||||
Lengths...,
|
||||
Strides...);
|
||||
});
|
||||
}
|
||||
|
||||
template <index_t... Lengths>
|
||||
__host__ __device__ void print_tensor_descriptor_impl(const char* s, Sequence<Lengths...>)
|
||||
{
|
||||
constexpr index_t nDim = sizeof...(Lengths);
|
||||
|
||||
static_assert(nDim > 0 && nDim <= 12, "wrong!");
|
||||
|
||||
static_if<nDim == 1>{}([&](auto) { printf("%s dim %u, lengths {%u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 2>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 3>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 4>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 5>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 6>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u}, \n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 7>{}(
|
||||
[&](auto) { printf("%s dim %u, lengths {%u %u %u %u %u %u %u}\n", s, nDim, Lengths...); });
|
||||
|
||||
static_if<nDim == 8>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 9>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 10>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 11>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
|
||||
static_if<nDim == 12>{}([&](auto) {
|
||||
printf("%s dim %u, lengths {%u %u %u %u %u %u %u %u %u %u %u %u}\n", s, nDim, Lengths...);
|
||||
});
|
||||
return ClusterDescriptor<Lengths, decltype(order)>{};
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -5,19 +5,17 @@
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "threadwise_gemm.hpp"
|
||||
|
||||
#ifndef CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
#define CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
// if following number are power of 2, index calculation shall be greatly reduced:
|
||||
// blockwise GEMM: C += transpose(A) * B
|
||||
// A and B are visible to the whole block, C is distributed among the threads
|
||||
// If the following numbers are powers of 2, index calculation is greatly reduced:
|
||||
// MPerThreadSubC, NPerThreadSubC, MLevel0ThreadCluster, NLevel0ThreadCluster,
|
||||
// MLevel1ThreadCluster, NLevel1ThreadCluster
|
||||
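// A plain-C++ reference for the math the struct below computes, C += transpose(A) * B,
// assuming A is stored K x M and B is stored K x N, both row-major (a sketch of the
// semantics only; the real implementation distributes C across the threads of a block
// and tiles over MPerThreadSubC / NPerThreadSubC etc.).
inline void sketch_gemm_tn(const float* a, // K x M
                           const float* b, // K x N
                           float* c,       // M x N, accumulated into
                           int M, int N, int K)
{
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
                c[m * N + n] += a[k * M + m] * b[k * N + n];
}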
template <index_t BlockSize,
|
||||
class BlockMatrixA,
|
||||
class BlockMatrixB,
|
||||
class ThreadMatrixC,
|
||||
typename BlockMatrixA,
|
||||
typename BlockMatrixB,
|
||||
typename ThreadMatrixC,
|
||||
index_t MPerThreadSubC,
|
||||
index_t NPerThreadSubC,
|
||||
index_t MLevel0ThreadCluster,
|
||||
@@ -117,233 +115,19 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
n_repeat * NPerLevel1Cluster + n_in_sub_c};
|
||||
}
|
||||
|
||||
#if CK_USE_AMD_INLINE_ASM
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run_amd_asm(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void
|
||||
Run_naive(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
|
||||
{
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert(is_same<FloatA, float>{} && is_same<FloatB, float>{} &&
|
||||
is_same<FloatC, float>{},
|
||||
"Run_amd_asm only deal with float");
|
||||
|
||||
static_assert(MPerThreadSubC == 4 && NPerThreadSubC == 4 && KPerThreadLoop == 1 &&
|
||||
MPerThread == 8 && NPerThread == 8,
|
||||
"Run_amd_asm cannot deal with this GEMM shape yet");
|
||||
|
||||
static_assert(DataPerReadA == 4 && DataPerReadB == 4, "Run_amd_asm only do float4 read");
|
||||
|
||||
using Float4 = vector_type<float, 4>::MemoryType;
|
||||
|
||||
Float4* reg_a = reinterpret_cast<Float4*>(p_a_thread);
|
||||
Float4* reg_b = reinterpret_cast<Float4*>(p_b_thread);
|
||||
Float4* reg_c = reinterpret_cast<Float4*>(p_c_thread);
|
||||
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA]);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB]);
|
||||
reg_b[1] =
|
||||
*reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + NPerLevel1Cluster]);
|
||||
reg_a[1] =
|
||||
*reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + MPerLevel1Cluster]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
reg_a[0] = *reinterpret_cast<const Float4*>(&p_a_block[mMyThreadOffsetA + k * M]);
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
reg_b[0] = *reinterpret_cast<const Float4*>(&p_b_block[mMyThreadOffsetB + k * N]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
reg_b[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_b_block[mMyThreadOffsetB + k * N + NPerLevel1Cluster]);
|
||||
reg_a[1] = *reinterpret_cast<const Float4*>(
|
||||
&p_a_block[mMyThreadOffsetA + k * M + MPerLevel1Cluster]);
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
|
||||
__device__ void Run_amd_asm_v2(const float* __restrict__ p_a_block,
|
||||
const float* __restrict__ p_b_block,
|
||||
float* __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t M = a_block_mtx.NCol();
|
||||
constexpr index_t N = b_block_mtx.NCol();
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
float p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
float p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MThreadCluster = MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NThreadCluster = NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
constexpr index_t MDataCluster = M / MPerThreadSubC;
|
||||
constexpr index_t NDataCluster = N / NPerThreadSubC;
|
||||
|
||||
constexpr index_t MRepeat = MDataCluster / MThreadCluster;
|
||||
constexpr index_t NRepeat = NDataCluster / NThreadCluster;
|
||||
|
||||
// assertion for inline asm
|
||||
static_assert((MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
|
||||
KPerThreadLoop == 1) ||
|
||||
(MPerThreadSubC == 2 && NPerThreadSubC == 4 && MRepeat == 2 &&
|
||||
NRepeat == 2 && KPerThreadLoop == 1),
|
||||
"Run_amd_asm cannot deal with this GEMM shape yet");
|
||||
|
||||
static_assert(DataPerReadA == MPerThreadSubC && DataPerReadB == NPerThreadSubC,
|
||||
"wrong! Run_amd_asm doesn't support this config");
|
||||
|
||||
if(MPerThreadSubC == 4 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
|
||||
KPerThreadLoop == 1)
|
||||
{
|
||||
using float4_type = vector_type<float, 4>::MemoryType;
|
||||
|
||||
float4_type* reg_a = reinterpret_cast<float4_type*>(p_a_thread);
|
||||
float4_type* reg_b = reinterpret_cast<float4_type*>(p_b_thread);
|
||||
float4_type* reg_c = reinterpret_cast<float4_type*>(p_c_thread);
|
||||
|
||||
const float4_type* p_a =
|
||||
reinterpret_cast<const float4_type*>(&p_a_block[mMyThreadOffsetA]);
|
||||
const float4_type* p_b =
|
||||
reinterpret_cast<const float4_type*>(&p_b_block[mMyThreadOffsetB]);
|
||||
|
||||
reg_a[0] = p_a[0];
|
||||
reg_b[0] = p_b[0];
|
||||
reg_b[1] = p_b[NThreadCluster];
|
||||
reg_a[1] = p_a[MThreadCluster];
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
reg_a[0] = p_a[k * MDataCluster];
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
reg_b[0] = p_b[k * NDataCluster];
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
reg_b[1] = p_b[k * NDataCluster + NThreadCluster];
|
||||
reg_a[1] = p_a[k * MDataCluster + MThreadCluster];
|
||||
outerProduct4x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2], reg_c[4], reg_c[6]);
|
||||
outerProduct4x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3], reg_c[5], reg_c[7]);
|
||||
}
|
||||
outerProduct4x4(reg_a[1], reg_b[0], reg_c[8], reg_c[10], reg_c[12], reg_c[14]);
|
||||
outerProduct4x4(reg_a[1], reg_b[1], reg_c[9], reg_c[11], reg_c[13], reg_c[15]);
|
||||
}
|
||||
else if(MPerThreadSubC == 2 && NPerThreadSubC == 4 && MRepeat == 2 && NRepeat == 2 &&
|
||||
KPerThreadLoop == 1)
|
||||
{
|
||||
using float2_type = vector_type<float, 2>::MemoryType;
|
||||
using float4_type = vector_type<float, 4>::MemoryType;
|
||||
|
||||
float2_type* reg_a = reinterpret_cast<float2_type*>(p_a_thread);
|
||||
float4_type* reg_b = reinterpret_cast<float4_type*>(p_b_thread);
|
||||
float4_type* reg_c = reinterpret_cast<float4_type*>(p_c_thread);
|
||||
|
||||
const float2_type* p_a =
|
||||
reinterpret_cast<const float2_type*>(&p_a_block[mMyThreadOffsetA]);
|
||||
const float4_type* p_b =
|
||||
reinterpret_cast<const float4_type*>(&p_b_block[mMyThreadOffsetB]);
|
||||
|
||||
reg_a[0] = p_a[0];
|
||||
reg_b[0] = p_b[0];
|
||||
reg_b[1] = p_b[NThreadCluster];
|
||||
reg_a[1] = p_a[MThreadCluster];
|
||||
outerProduct2x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2]);
|
||||
outerProduct2x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3]);
|
||||
#pragma unroll
|
||||
for(index_t k = 1; k < K; ++k)
|
||||
{
|
||||
reg_a[0] = p_a[k * MDataCluster];
|
||||
outerProduct2x4(reg_a[1], reg_b[0], reg_c[4], reg_c[6]);
|
||||
reg_b[0] = p_b[k * NDataCluster];
|
||||
outerProduct2x4(reg_a[1], reg_b[1], reg_c[5], reg_c[7]);
|
||||
reg_b[1] = p_b[k * NDataCluster + NThreadCluster];
|
||||
reg_a[1] = p_a[k * MDataCluster + MThreadCluster];
|
||||
outerProduct2x4(reg_a[0], reg_b[0], reg_c[0], reg_c[2]);
|
||||
outerProduct2x4(reg_a[0], reg_b[1], reg_c[1], reg_c[3]);
|
||||
}
|
||||
outerProduct2x4(reg_a[1], reg_b[0], reg_c[4], reg_c[6]);
|
||||
outerProduct2x4(reg_a[1], reg_b[1], reg_c[5], reg_c[7]);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run_source(const FloatA* const __restrict__ p_a_block,
|
||||
const FloatB* const __restrict__ p_b_block,
|
||||
FloatC* const __restrict__ p_c_thread) const
|
||||
{
|
||||
constexpr auto True = integral_constant<bool, true>{};
|
||||
constexpr auto False = integral_constant<bool, false>{};
|
||||
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub for copy
|
||||
constexpr auto a_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_sub_mtx = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
@@ -352,63 +136,211 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
// thread A, B for GEMM
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
|
||||
decltype(a_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
MPerThreadSubC,
|
||||
DataPerReadA>{};
|
||||
|
||||
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
|
||||
decltype(b_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
NPerThreadSubC,
|
||||
DataPerReadB>{};
|
||||
|
||||
constexpr auto threadwise_gemm =
|
||||
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_mtx),
|
||||
decltype(b_thread_mtx),
|
||||
decltype(c_thread_mtx)>{};
|
||||
#pragma unroll
|
||||
// loop over k
|
||||
for(index_t k_begin = 0; k_begin < K; k_begin += KPerThreadLoop)
|
||||
{
|
||||
#pragma unroll
|
||||
// copy A-sub to form A
|
||||
// read A
|
||||
for(index_t m_repeat = 0; m_repeat < MRepeat; ++m_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
a_block_mtx,
|
||||
p_a_block +
|
||||
a_block_mtx.GetOffsetFromMultiIndex(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
a_thread_copy.Run(
|
||||
p_a_block + a_block_mtx.CalculateOffset(k_begin, m_repeat * MPerLevel1Cluster) +
|
||||
mMyThreadOffsetA,
|
||||
a_thread_mtx,
|
||||
p_a_thread + a_thread_mtx.GetOffsetFromMultiIndex(0, m_repeat * MPerThreadSubC),
|
||||
a_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadA>{});
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(0, m_repeat * MPerThreadSubC));
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
// copy B-sub to form B
|
||||
// read B
|
||||
for(index_t n_repeat = 0; n_repeat < NRepeat; ++n_repeat)
|
||||
{
|
||||
threadwise_matrix_copy(
|
||||
b_block_mtx,
|
||||
p_b_block +
|
||||
b_block_mtx.GetOffsetFromMultiIndex(k_begin, n_repeat * NPerLevel1Cluster) +
|
||||
b_thread_copy.Run(
|
||||
p_b_block + b_block_mtx.CalculateOffset(k_begin, n_repeat * NPerLevel1Cluster) +
|
||||
mMyThreadOffsetB,
|
||||
b_thread_mtx,
|
||||
p_b_thread + b_thread_mtx.GetOffsetFromMultiIndex(0, n_repeat * NPerThreadSubC),
|
||||
b_thread_sub_mtx.GetLengths(),
|
||||
Number<DataPerReadB>{});
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(0, n_repeat * NPerThreadSubC));
|
||||
}
|
||||
|
||||
// C = A * B
|
||||
threadwise_gemm(a_thread_mtx,
|
||||
True,
|
||||
p_a_thread,
|
||||
b_thread_mtx,
|
||||
False,
|
||||
p_b_thread,
|
||||
c_thread_mtx,
|
||||
False,
|
||||
p_c_thread);
|
||||
// C += A * B
|
||||
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
|
||||
}
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC>
|
||||
__device__ void Run(const FloatA* __restrict__ p_a_block,
|
||||
const FloatB* __restrict__ p_b_block,
|
||||
FloatC* __restrict__ p_c_thread) const
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void
|
||||
Run_pipelined_2x2(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
|
||||
{
|
||||
#if CK_USE_AMD_INLINE_ASM && CK_BLOCKWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
Run_amd_asm_v2(p_a_block, p_b_block, p_c_thread);
|
||||
constexpr auto a_block_mtx = BlockMatrixA{};
|
||||
constexpr auto b_block_mtx = BlockMatrixB{};
|
||||
constexpr auto c_thread_mtx = ThreadMatrixC{};
|
||||
|
||||
constexpr index_t K = a_block_mtx.NRow();
|
||||
|
||||
constexpr index_t MPerThread = c_thread_mtx.NRow();
|
||||
constexpr index_t NPerThread = c_thread_mtx.NCol();
|
||||
|
||||
constexpr index_t MPerLevel1Cluster =
|
||||
MPerThreadSubC * MLevel0ThreadCluster * MLevel1ThreadCluster;
|
||||
constexpr index_t NPerLevel1Cluster =
|
||||
NPerThreadSubC * NLevel0ThreadCluster * NLevel1ThreadCluster;
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_assert(MRepeat == 2 && NRepeat == 2,
|
||||
"wrong! inline asm cannot deal with this GEMM config yet");
|
||||
|
||||
// thread A, B
|
||||
constexpr auto a_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<MPerThread>{});
|
||||
constexpr auto b_thread_mtx =
|
||||
make_ConstantMatrixDescriptor_packed(Number<KPerThreadLoop>{}, Number<NPerThread>{});
|
||||
|
||||
// thread A-sub, B-sub
|
||||
constexpr auto a_thread_sub_mtx = a_thread_mtx.MakeSubMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<MPerThreadSubC>{});
|
||||
constexpr auto b_thread_sub_mtx = b_thread_mtx.MakeSubMatrixDescriptor(
|
||||
Number<KPerThreadLoop>{}, Number<NPerThreadSubC>{});
|
||||
|
||||
// thread C-sub
|
||||
constexpr auto c_thread_sub_mtx = ThreadMatrixC::MakeSubMatrixDescriptor(
|
||||
Number<MPerThreadSubC>{}, Number<NPerThreadSubC>{});
|
||||
|
||||
FloatA p_a_thread[a_thread_mtx.GetElementSpace()];
|
||||
FloatB p_b_thread[b_thread_mtx.GetElementSpace()];
|
||||
|
||||
constexpr auto a_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixA,
|
||||
decltype(a_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
MPerThreadSubC,
|
||||
DataPerReadA>{};
|
||||
|
||||
constexpr auto b_thread_copy = ThreadwiseMatrixSliceCopy<BlockMatrixB,
|
||||
decltype(b_thread_mtx),
|
||||
KPerThreadLoop,
|
||||
NPerThreadSubC,
|
||||
DataPerReadB>{};
|
||||
|
||||
constexpr auto threadwise_gemm =
|
||||
ThreadwiseGemmTransANormalBNormalC<decltype(a_thread_sub_mtx),
|
||||
decltype(b_thread_sub_mtx),
|
||||
decltype(c_thread_sub_mtx)>{};
|
||||
|
||||
const FloatA* p_a_block_off = p_a_block + mMyThreadOffsetA;
|
||||
const FloatB* p_b_block_off = p_b_block + mMyThreadOffsetB;
|
||||
|
||||
// read A_sub_0
|
||||
a_thread_copy.Run(p_a_block_off, p_a_thread);
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy.Run(p_b_block_off, p_b_thread);
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(0, NPerLevel1Cluster),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(0, MPerLevel1Cluster),
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(p_a_thread,
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
|
||||
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
|
||||
|
||||
#pragma unroll
|
||||
// loop over rest of k
|
||||
for(index_t k = KPerThreadLoop; k < K; k += KPerThreadLoop)
|
||||
{
|
||||
// read A_sub_0
|
||||
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, 0), p_a_thread);
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
|
||||
p_b_thread,
|
||||
p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
|
||||
|
||||
// read B_sub_0
|
||||
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, 0), p_b_thread);
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
|
||||
p_c_thread +
|
||||
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
|
||||
|
||||
// read B_sub_1
|
||||
b_thread_copy.Run(p_b_block_off + b_block_mtx.CalculateOffset(k, NPerLevel1Cluster),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC));
|
||||
|
||||
// read A_sub_1
|
||||
a_thread_copy.Run(p_a_block_off + a_block_mtx.CalculateOffset(k, MPerLevel1Cluster),
|
||||
p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC));
|
||||
|
||||
// C_sub_00 += transpose(A_sub_0) * B_sub_0
|
||||
threadwise_gemm.Run(p_a_thread, p_b_thread, p_c_thread);
|
||||
|
||||
// C_sub_01 += transpose(A_sub_0) * B_sub_1
|
||||
threadwise_gemm.Run(p_a_thread,
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
|
||||
p_c_thread + ThreadMatrixC::CalculateOffset(0, NPerThreadSubC));
|
||||
}
|
||||
|
||||
// C_sub_10 += transpose(A_sub_1) * B_sub_0
|
||||
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
|
||||
p_b_thread,
|
||||
p_c_thread + ThreadMatrixC::CalculateOffset(MPerThreadSubC, 0));
|
||||
|
||||
// C_sub_11 += transpose(A_sub_1) * B_sub_1
|
||||
threadwise_gemm.Run(p_a_thread + a_thread_mtx.CalculateOffset(0, MPerThreadSubC),
|
||||
p_b_thread + b_thread_mtx.CalculateOffset(0, NPerThreadSubC),
|
||||
p_c_thread +
|
||||
ThreadMatrixC::CalculateOffset(MPerThreadSubC, NPerThreadSubC));
|
||||
}
|
||||
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ void Run(const FloatA* p_a_block, const FloatB* p_b_block, FloatC* p_c_thread) const
|
||||
{
|
||||
#if CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE
|
||||
constexpr index_t MPerThread = ThreadMatrixC::NRow();
|
||||
constexpr index_t NPerThread = ThreadMatrixC::NCol();
|
||||
|
||||
constexpr index_t MRepeat = MPerThread / MPerThreadSubC;
|
||||
constexpr index_t NRepeat = NPerThread / NPerThreadSubC;
|
||||
|
||||
static_if<MRepeat == 2 && NRepeat == 2>{}([&](auto) {
|
||||
Run_pipelined_2x2(p_a_block, p_b_block, p_c_thread);
|
||||
}).Else([&](auto) { Run_naive(p_a_block, p_b_block, p_c_thread); });
|
||||
#else
|
||||
Run_source(p_a_block, p_b_block, p_c_thread);
|
||||
Run_naive(p_a_block, p_b_block, p_c_thread);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -68,64 +68,118 @@ struct BlockwiseGenericTensorSliceCopy_v4
|
||||
|
||||
template <typename BlockSrcData,
|
||||
typename ThreadBufferData,
|
||||
address_space_t BlockSrcAddressSpace = address_space_t::generic,
|
||||
address_space_t ThreadBufferAddressSpace = address_space_t::generic>
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace ThreadBufferAddressSpace>
|
||||
__device__ void
|
||||
RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>) const
|
||||
{
|
||||
constexpr auto block_src_address_space =
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>{};
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
|
||||
constexpr bool has_optimized_address_calculation =
|
||||
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
|
||||
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
{
|
||||
mThreadwiseLoad.Run_optimized_src_address_calculation(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseLoad.Run(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename BlockSrcData, typename ThreadBufferData>
|
||||
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer) const
|
||||
{
|
||||
#if 1
|
||||
mThreadwiseLoad.template Run<BlockSrcData,
|
||||
ThreadBufferData,
|
||||
BlockSrcAddressSpace,
|
||||
ThreadBufferAddressSpace>(p_block_src, p_thread_buffer);
|
||||
#else // tweaking
|
||||
mThreadwiseLoad.template Run_optimized_src_address_calculation<BlockSrcData,
|
||||
ThreadBufferData,
|
||||
BlockSrcAddressSpace,
|
||||
ThreadBufferAddressSpace>(
|
||||
p_block_src, p_thread_buffer);
|
||||
#endif
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename ThreadBufferData,
|
||||
typename BlockDstData,
|
||||
address_space_t ThreadBufferAddressSpace = address_space_t::generic,
|
||||
address_space_t BlockDstAddressSpace = address_space_t::generic>
|
||||
AddressSpace ThreadBufferAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>) const
|
||||
{
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
constexpr auto block_dst_address_space =
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>{};
|
||||
|
||||
constexpr bool has_optimized_address_calculation =
|
||||
decltype(mThreadwiseStore)::HasWorkingOptimizedAddressCalculation();
|
||||
|
||||
// TODO: threadwise copy is still being tweaked
|
||||
if(has_optimized_address_calculation)
|
||||
{
|
||||
mThreadwiseStore.Run_optimized_dst_address_calculation(
|
||||
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
|
||||
}
|
||||
else
|
||||
{
|
||||
mThreadwiseStore.Run(
|
||||
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ThreadBufferData, typename BlockDstData>
|
||||
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst) const
|
||||
{
|
||||
#if 1
|
||||
mThreadwiseStore.template Run<ThreadBufferData,
|
||||
BlockDstData,
|
||||
ThreadBufferAddressSpace,
|
||||
BlockDstAddressSpace>(p_thread_buffer, p_block_dst);
|
||||
#else // tweaking
|
||||
mThreadwiseStore.template Run_optimized_dst_address_calculation<ThreadBufferData,
|
||||
BlockDstData,
|
||||
ThreadBufferAddressSpace,
|
||||
BlockDstAddressSpace>(
|
||||
p_thread_buffer, p_block_dst);
|
||||
#endif
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename BlockSrcData,
|
||||
typename BlockDstData,
|
||||
address_space_t BlockSrcAddressSpace = address_space_t::generic,
|
||||
address_space_t BlockDstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
Run(const BlockSrcData* p_block_src,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace> block_src_address_space,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace> block_dst_address_space) const
|
||||
{
|
||||
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
|
||||
RunLoadThreadBuffer<BlockSrcData,
|
||||
BlockSrcData,
|
||||
BlockSrcAddressSpace,
|
||||
address_space_t::generic>(p_block_src, p_thread_buffer);
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
|
||||
|
||||
// if there is type conversion, it's done during store
|
||||
RunStoreThreadBuffer<BlockSrcData,
|
||||
BlockDstData,
|
||||
address_space_t::generic,
|
||||
BlockDstAddressSpace>(p_thread_buffer, p_block_dst);
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, block_dst_address_space);
|
||||
}
|
||||
|
||||
    template <typename BlockSrcData, typename BlockDstData>
    __device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
    {
        constexpr auto generic_address_space =
            integral_constant<AddressSpace, AddressSpace::generic>{};

        Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
    }
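// Sketch (illustrative, not part of this library): the tag-dispatch pattern
// the refactored Run()/RunLoadThreadBuffer() overloads above rely on, reduced
// to a self-contained example. `Space`, `space_tag` and `Copier` are
// placeholder names; the point is that the address space travels as a
// compile-time integral_constant argument instead of a defaulted template
// parameter, with a generic-to-generic convenience overload on top.
#include <type_traits>

namespace sketch {

enum class Space { generic, global };

template <Space S>
using space_tag = std::integral_constant<Space, S>;

struct Copier
{
    // primary overload: both address spaces are compile-time tags
    template <Space Src, Space Dst>
    void Run(const float* p_src, float* p_dst, space_tag<Src>, space_tag<Dst>) const
    {
        // a real implementation would pick buffer/LDS/flat instructions per tag
        *p_dst = *p_src;
    }

    // convenience overload: default both sides to the generic address space
    void Run(const float* p_src, float* p_dst) const
    {
        Run(p_src, p_dst, space_tag<Space::generic>{}, space_tag<Space::generic>{});
    }
};

} // namespace sketch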
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
|
||||
@@ -2,15 +2,11 @@
|
||||
#define CK_BLOCKWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "tensor_coordinate_deprecated.hpp"
|
||||
#include "threadwise_generic_tensor_slice_copy_deprecated.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
// Slice a (normal or merged) tensor, and copy it into another (normal or merged) tensor
|
||||
@@ -20,7 +16,7 @@ namespace ck {
|
||||
// that, on a merged dimension that contains multiple original dimensions, the length of
// the last original dimension needs to be evenly divisible by its sub-lengths. Also, the
// repeat-length on the merged dimension needs to be 1. These sanity checks are performed
|
||||
// in constructor of BlockwiseGenericTensorSliceCopy_v1
|
||||
// in constructor of BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
template <index_t BlockSize,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
@@ -34,7 +30,7 @@ template <index_t BlockSize,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v1
|
||||
struct BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
@@ -62,7 +58,8 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
Array<index_t, nOriginalDimSrc> mThreadSrcOriginalMultiId;
|
||||
Array<index_t, nOriginalDimDst> mThreadDstOriginalMultiId;
|
||||
|
||||
__device__ BlockwiseGenericTensorSliceCopy_v1(Array<index_t, nDim> src_block_data_id_begin,
|
||||
__device__
|
||||
BlockwiseGenericTensorSliceCopy_v1_deprecated(Array<index_t, nDim> src_block_data_id_begin,
|
||||
Array<index_t, nDim> dst_block_data_id_begin)
|
||||
{
|
||||
// check NDim consistency
|
||||
@@ -196,14 +193,14 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
return make_ConstantTensorDescriptor_packed(SubLengths{} * repeat_lengths);
|
||||
}
|
||||
|
||||
__device__ static constexpr index_t GetRegisterBufferSize()
|
||||
__device__ static constexpr index_t GetThreadBufferSize()
|
||||
{
|
||||
return GetRegisterBufferDescriptor().GetElementSpace();
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void RunLoadRegisterBuffer(const TData* __restrict__ p_src,
|
||||
TData* __restrict__ p_buffer) const
|
||||
__device__ void RunLoadThreadBuffer(const TData* __restrict__ p_src,
|
||||
TData* __restrict__ p_buffer) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
@@ -244,22 +241,22 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
// that contains multiple original dimensions, the length of the last original
// dimension needs to be evenly divisible by its sub-lengths. Also, the repeat-length on
// the merged dimension needs to be 1. These sanity checks are performed in constructor
|
||||
// of BlockwiseGenericTensorSliceCopy_v1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<SrcDesc,
|
||||
decltype(thread_buffer_desc),
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
// of BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<SrcDesc,
|
||||
decltype(thread_buffer_desc),
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
.Run(p_src + src_offset + mThreadSrcOffset, p_buffer + buffer_offset);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void RunStoreRegisterBuffer(const TData* __restrict__ p_buffer,
|
||||
TData* __restrict__ p_dst) const
|
||||
__device__ void RunStoreThreadBuffer(const TData* __restrict__ p_buffer,
|
||||
TData* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto thread_sub_tensor_lengths = SubLengths{};
|
||||
|
||||
@@ -299,14 +296,14 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
// that contains multiple original dimensions, the length of the last original
// dimension needs to be evenly divisible by its sub-lengths. Also, the repeat-length on
// the merged dimension needs to be 1. These sanity checks are performed in constructor
|
||||
// of BlockwiseGenericTensorSliceCopy_v1
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2<decltype(thread_buffer_desc),
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>(
|
||||
// of BlockwiseGenericTensorSliceCopy_v1_deprecated
|
||||
ThreadwiseGenericTensorSliceCopy_v1r2_deprecated<decltype(thread_buffer_desc),
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>(
|
||||
make_zero_array<index_t, nDim>(), make_zero_array<index_t, nDim>())
|
||||
.Run(p_buffer + buffer_offset, p_dst + dst_offset + mThreadDstOffset);
|
||||
});
|
||||
@@ -315,10 +312,10 @@ struct BlockwiseGenericTensorSliceCopy_v1
|
||||
template <typename TData>
|
||||
__device__ void Run(const TData* __restrict__ p_src, TData* __restrict__ p_dst) const
|
||||
{
|
||||
TData p_buffer[GetRegisterBufferSize()];
|
||||
TData p_buffer[GetThreadBufferSize()];
|
||||
|
||||
RunLoadRegisterBuffer(p_src, p_buffer);
|
||||
RunStoreRegisterBuffer(p_buffer, p_dst);
|
||||
RunLoadThreadBuffer(p_src, p_buffer);
|
||||
RunStoreThreadBuffer(p_buffer, p_dst);
|
||||
}
|
||||
|
||||
// When moving the slicing windows along a merged dimension, if the strides of the
|
||||
@@ -432,14 +429,14 @@ template <index_t BlockSize,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct BlockwiseGenericTensorSliceCopy_v2
|
||||
struct BlockwiseGenericTensorSliceCopy_v2_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SrcDesc::GetNumOfDimension();
|
||||
|
||||
using Index = MultiIndex<nDim>;
|
||||
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2(const Index& src_block_slice_origin,
|
||||
const Index& dst_block_slice_origin)
|
||||
__device__ constexpr BlockwiseGenericTensorSliceCopy_v2_deprecated(
|
||||
const Index& src_block_slice_origin, const Index& dst_block_slice_origin)
|
||||
{
|
||||
static_assert(
|
||||
nDim == SrcDesc::GetNumOfDimension() && nDim == DstDesc::GetNumOfDimension() &&
|
||||
@@ -478,42 +475,96 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
return ThreadBufferDesc::GetElementSpace();
|
||||
}
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t BlockSrcAddressSpace = address_space_t::generic,
|
||||
address_space_t ThreadBufferAddressSpace = address_space_t::generic>
|
||||
__device__ void RunLoadThreadBuffer(const SrcData* p_block_src, DstData* p_thread_buffer) const
|
||||
template <typename BlockSrcData,
|
||||
typename ThreadBufferData,
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace ThreadBufferAddressSpace>
|
||||
__device__ void
|
||||
RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>) const
|
||||
{
|
||||
mThreadwiseLoad
|
||||
.template Run<SrcData, DstData, BlockSrcAddressSpace, ThreadBufferAddressSpace>(
|
||||
p_block_src, p_thread_buffer);
|
||||
constexpr auto block_src_address_space =
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace>{};
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
|
||||
mThreadwiseLoad.Run(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, thread_buffer_address_space);
|
||||
}
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t ThreadBufferAddressSpace = address_space_t::generic,
|
||||
address_space_t BlockDstAddressSpace = address_space_t::generic>
|
||||
__device__ void RunStoreThreadBuffer(const SrcData* p_thread_buffer, DstData* p_block_dst) const
|
||||
template <typename BlockSrcData, typename ThreadBufferData>
|
||||
__device__ void RunLoadThreadBuffer(const BlockSrcData* p_block_src,
|
||||
ThreadBufferData* p_thread_buffer) const
|
||||
{
|
||||
mThreadwiseStore
|
||||
.template Run<SrcData, DstData, ThreadBufferAddressSpace, BlockDstAddressSpace>(
|
||||
p_thread_buffer, p_block_dst);
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t BlockSrcAddressSpace = address_space_t::generic,
|
||||
address_space_t BlockDstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run(const SrcData* p_block_src, DstData* p_block_dst) const
|
||||
template <typename ThreadBufferData,
|
||||
typename BlockDstData,
|
||||
AddressSpace ThreadBufferAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>) const
|
||||
{
|
||||
SrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
constexpr auto thread_buffer_address_space =
|
||||
integral_constant<AddressSpace, ThreadBufferAddressSpace>{};
|
||||
constexpr auto block_dst_address_space =
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace>{};
|
||||
|
||||
RunLoadThreadBuffer<SrcData, SrcData, BlockSrcAddressSpace, address_space_t::generic>(
|
||||
p_block_src, p_thread_buffer);
|
||||
mThreadwiseStore.Run(
|
||||
p_thread_buffer, p_block_dst, thread_buffer_address_space, block_dst_address_space);
|
||||
}
|
||||
|
||||
template <typename ThreadBufferData, typename BlockDstData>
|
||||
__device__ void RunStoreThreadBuffer(const ThreadBufferData* p_thread_buffer,
|
||||
BlockDstData* p_block_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename BlockSrcData,
|
||||
typename BlockDstData,
|
||||
AddressSpace BlockSrcAddressSpace,
|
||||
AddressSpace BlockDstAddressSpace>
|
||||
__device__ void
|
||||
Run(const BlockSrcData* p_block_src,
|
||||
BlockDstData* p_block_dst,
|
||||
integral_constant<AddressSpace, BlockSrcAddressSpace> block_src_address_space,
|
||||
integral_constant<AddressSpace, BlockDstAddressSpace> block_dst_address_space) const
|
||||
{
|
||||
BlockSrcData p_thread_buffer[GetThreadBufferSize()];
|
||||
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
RunLoadThreadBuffer(
|
||||
p_block_src, p_thread_buffer, block_src_address_space, generic_address_space);
|
||||
|
||||
// if there is type conversion, it's done during store
|
||||
RunStoreThreadBuffer<SrcData, DstData, address_space_t::generic, BlockDstAddressSpace>(
|
||||
p_thread_buffer, p_block_dst);
|
||||
RunStoreThreadBuffer(
|
||||
p_thread_buffer, p_block_dst, generic_address_space, block_dst_address_space);
|
||||
}
|
||||
|
||||
template <typename BlockSrcData, typename BlockDstData>
|
||||
__device__ void Run(const BlockSrcData* p_block_src, BlockDstData* p_block_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
Run(p_block_src, p_block_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
@@ -533,25 +584,25 @@ struct BlockwiseGenericTensorSliceCopy_v2
|
||||
private:
|
||||
using ThreadBufferDesc = decltype(make_ConstantTensorDescriptor_packed(SubLengths{}));
|
||||
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v2r1<SrcDesc,
|
||||
ThreadBufferDesc,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
using ThreadwiseLoad = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<SrcDesc,
|
||||
ThreadBufferDesc,
|
||||
SubLengths,
|
||||
SrcDimAccessOrder,
|
||||
SrcDimAccessOrder,
|
||||
SrcVectorAccessDim,
|
||||
SrcVectorAccessDim,
|
||||
SrcDataPerAccess,
|
||||
1>;
|
||||
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v2r1<ThreadBufferDesc,
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
using ThreadwiseStore = ThreadwiseGenericTensorSliceCopy_v2r1_deprecated<ThreadBufferDesc,
|
||||
DstDesc,
|
||||
SubLengths,
|
||||
DstDimAccessOrder,
|
||||
DstDimAccessOrder,
|
||||
DstVectorAccessDim,
|
||||
DstVectorAccessDim,
|
||||
1,
|
||||
DstDataPerAccess>;
|
||||
|
||||
ThreadwiseLoad mThreadwiseLoad;
|
||||
ThreadwiseStore mThreadwiseStore;
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
#define CK_THREADWISE_DIRECT_CONVOLUTION_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "threadwise_tensor_slice_copy.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -3,102 +3,164 @@
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantMatrixDescriptor.hpp"
|
||||
#include "math.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <class Float, class Matrix>
|
||||
template <typename Float, class Matrix>
|
||||
__device__ void threadwise_matrix_set_zero(Matrix, Float* __restrict__ p_thread)
|
||||
{
|
||||
for(index_t i = 0; i < Matrix::NRow(); ++i)
|
||||
{
|
||||
for(index_t j = 0; j < Matrix::NCol(); ++j)
|
||||
{
|
||||
const index_t id = Matrix::GetOffsetFromMultiIndex(i, j);
|
||||
const index_t id = Matrix::CalculateOffset(i, j);
|
||||
p_thread[id] = Float(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class Float,
|
||||
class SrcMatrix,
|
||||
class DstMatrix,
|
||||
index_t NRow,
|
||||
index_t NCol,
|
||||
index_t DataPerRead>
|
||||
__device__ void threadwise_matrix_copy(SrcMatrix,
|
||||
const Float* __restrict__ p_src,
|
||||
DstMatrix,
|
||||
Float* __restrict__ p_dst,
|
||||
Sequence<NRow, NCol>,
|
||||
Number<DataPerRead>)
|
||||
template <typename SrcMatrix,
|
||||
typename DstMatrix,
|
||||
index_t NSliceRow,
|
||||
index_t NSliceCol,
|
||||
index_t DataPerAccess>
|
||||
struct ThreadwiseMatrixSliceCopy
|
||||
{
|
||||
static_assert(NCol % DataPerRead == 0, "wrong! should be NCol % == DataPerRead == 0");
|
||||
|
||||
using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;
|
||||
|
||||
constexpr auto src_mtx = SrcMatrix{};
|
||||
constexpr auto dst_mtx = DstMatrix{};
|
||||
|
||||
for(index_t i = 0; i < NRow; ++i)
|
||||
__device__ constexpr ThreadwiseMatrixSliceCopy()
|
||||
{
|
||||
for(index_t j = 0; j < NCol; j += DataPerRead)
|
||||
{
|
||||
const index_t src_index = src_mtx.GetOffsetFromMultiIndex(i, j);
|
||||
const index_t dst_index = dst_mtx.GetOffsetFromMultiIndex(i, j);
|
||||
static_assert(SrcMatrix::RowStride() % DataPerAccess == 0 &&
|
||||
DstMatrix::RowStride() % DataPerAccess == 0,
|
||||
"wrong! wrong alignment");
|
||||
static_assert(NSliceCol % DataPerAccess == 0,
|
||||
"wrong! should be NSliceCol % DataPerAccess == 0");
|
||||
}
|
||||
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
|
||||
*reinterpret_cast<const vector_t*>(&p_src[src_index]);
|
||||
    template <typename Data>
    __device__ static void Run(const Data* p_src, Data* p_dst)
    {
        using vector_t = typename vector_type<Data, DataPerAccess>::MemoryType;

        for(index_t i = 0; i < NSliceRow; ++i)
        {
            for(index_t j = 0; j < NSliceCol; j += DataPerAccess)
            {
                const index_t src_index = SrcMatrix::CalculateOffset(i, j);
                const index_t dst_index = DstMatrix::CalculateOffset(i, j);

                *reinterpret_cast<vector_t*>(&p_dst[dst_index]) =
                    *reinterpret_cast<const vector_t*>(&p_src[src_index]);
            }
        }
    }
}
};
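// Sketch (illustrative, not part of this library): a scalar, host-side
// analogue of ThreadwiseMatrixSliceCopy::Run above, assuming plain row-major
// matrices with compile-time row strides. `copy_matrix_slice` and the stride
// parameters are placeholder names; the vector_type access is replaced by a
// short inner loop over DataPerAccess contiguous elements.
template <int NSliceRow, int NSliceCol, int DataPerAccess, int SrcRowStride, int DstRowStride>
void copy_matrix_slice(const float* p_src, float* p_dst)
{
    static_assert(NSliceCol % DataPerAccess == 0,
                  "column slice must be a multiple of the access width");

    for(int i = 0; i < NSliceRow; ++i)
    {
        for(int j = 0; j < NSliceCol; j += DataPerAccess)
        {
            // one "vector" access: DataPerAccess contiguous elements of a row
            for(int v = 0; v < DataPerAccess; ++v)
            {
                p_dst[i * DstRowStride + j + v] = p_src[i * SrcRowStride + j + v];
            }
        }
    }
}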
|
||||
|
||||
template <class MatrixA,
|
||||
class MatrixB,
|
||||
class MatrixC,
|
||||
bool TransA,
|
||||
bool TransB,
|
||||
bool TransC,
|
||||
class FloatA,
|
||||
class FloatB,
|
||||
class FloatC>
|
||||
__device__ void threadwise_gemm(MatrixA,
|
||||
integral_constant<bool, TransA>,
|
||||
const FloatA* __restrict__ p_a_thread,
|
||||
MatrixB,
|
||||
integral_constant<bool, TransB>,
|
||||
const FloatB* __restrict__ p_b_thread,
|
||||
MatrixC,
|
||||
integral_constant<bool, TransC>,
|
||||
FloatC* __restrict__ p_c_thread)
|
||||
// C += transpose(A) * B
|
||||
// Element of matrix can be vectorized data
|
||||
template <typename MatrixA, typename MatrixB, typename MatrixC>
|
||||
struct ThreadwiseGemmTransANormalBNormalC
|
||||
{
|
||||
static_if<TransA && (!TransB) && (!TransC)>{}([&](auto) {
|
||||
constexpr auto a_mtx = MatrixA{};
|
||||
constexpr auto b_mtx = MatrixB{};
|
||||
constexpr auto c_mtx = MatrixC{};
|
||||
__device__ constexpr ThreadwiseGemmTransANormalBNormalC()
|
||||
{
|
||||
static_assert(MatrixA::NRow() == MatrixB::NRow() && MatrixA::NCol() == MatrixC::NRow() &&
|
||||
MatrixB::NCol() == MatrixC::NCol(),
|
||||
"wrong!");
|
||||
}
|
||||
|
||||
constexpr index_t M = c_mtx.NRow();
|
||||
constexpr index_t N = c_mtx.NCol();
|
||||
constexpr index_t K = a_mtx.NRow(); // A is transposed
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
|
||||
{
|
||||
constexpr index_t M = MatrixC::NRow();
|
||||
constexpr index_t N = MatrixC::NCol();
|
||||
constexpr index_t K = MatrixA::NRow(); // A is transposed
|
||||
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t i = 0; i < M; ++i)
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
for(index_t j = 0; j < N; ++j)
|
||||
for(index_t n = 0; n < N; ++n)
|
||||
{
|
||||
const index_t aindex = a_mtx.GetOffsetFromMultiIndex(k, i); // A is transposed
|
||||
const index_t bindex = b_mtx.GetOffsetFromMultiIndex(k, j);
|
||||
const index_t cindex = c_mtx.GetOffsetFromMultiIndex(i, j);
|
||||
const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
|
||||
const index_t bindex = MatrixB::CalculateOffset(k, n);
|
||||
const index_t cindex = MatrixC::CalculateOffset(m, n);
|
||||
|
||||
p_c_thread[cindex] += p_a_thread[aindex] * p_b_thread[bindex];
|
||||
p_c[cindex] +=
|
||||
inner_product_with_conversion<FloatC>{}(p_a[aindex], p_b[bindex]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}).Else([&](auto fwd) {
|
||||
// not implemented
|
||||
static_assert(fwd(false), "wrong! support for this config is not implemented");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
template <typename FloatA, typename FloatB, typename FloatC>
|
||||
__device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
|
||||
{
|
||||
constexpr index_t M = MatrixC::NRow();
|
||||
constexpr index_t N = MatrixC::NCol();
|
||||
constexpr index_t K = MatrixA::NRow(); // A is transposed
|
||||
|
||||
static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
|
||||
|
||||
for(index_t k = 0; k < K; ++k)
|
||||
{
|
||||
for(index_t m = 0; m < M; ++m)
|
||||
{
|
||||
const index_t aindex = MatrixA::CalculateOffset(k, m); // A is transposed
|
||||
|
||||
static_if<N == 2>{}([&](auto) {
|
||||
const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
|
||||
const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
|
||||
|
||||
const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
|
||||
const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
|
||||
|
||||
__outer_product_1x2(
|
||||
p_a[aindex], p_b[bindex_0], p_b[bindex_1], p_c[cindex_0], p_c[cindex_1]);
|
||||
});
|
||||
|
||||
static_if<N == 4>{}([&](auto) {
|
||||
const index_t bindex_0 = MatrixB::CalculateOffset(k, 0);
|
||||
const index_t bindex_1 = MatrixB::CalculateOffset(k, 1);
|
||||
const index_t bindex_2 = MatrixB::CalculateOffset(k, 2);
|
||||
const index_t bindex_3 = MatrixB::CalculateOffset(k, 3);
|
||||
|
||||
const index_t cindex_0 = MatrixC::CalculateOffset(m, 0);
|
||||
const index_t cindex_1 = MatrixC::CalculateOffset(m, 1);
|
||||
const index_t cindex_2 = MatrixC::CalculateOffset(m, 2);
|
||||
const index_t cindex_3 = MatrixC::CalculateOffset(m, 3);
|
||||
|
||||
__outer_product_1x4(p_a[aindex],
|
||||
p_b[bindex_0],
|
||||
p_b[bindex_1],
|
||||
p_b[bindex_2],
|
||||
p_b[bindex_3],
|
||||
p_c[cindex_0],
|
||||
p_c[cindex_1],
|
||||
p_c[cindex_2],
|
||||
p_c[cindex_3]);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
    template <typename FloatA, typename FloatB, typename FloatC>
    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
    {
#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
        constexpr bool has_amd_asm = is_same<FloatC, float>{} &&
                                     ((is_same<FloatA, float>{} && is_same<FloatB, float>{}) ||
                                      (is_same<FloatA, half2_t>{} && is_same<FloatB, half2_t>{}) ||
                                      (is_same<FloatA, half4_t>{} && is_same<FloatB, half4_t>{}));

        static_if<has_amd_asm>{}([&](auto fwd) {
            Run_amd_asm(p_a, p_b, fwd(p_c));
        }).Else([&](auto) { Run_source(p_a, p_b, p_c); });
#else
        Run_source(p_a, p_b, p_c);
#endif
    }
};
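// Sketch (illustrative, not part of this library): what the inline-asm path
// above expands to, written as plain C++. Run_amd_asm walks K and, for each k,
// issues a rank-1 update of one A element against an N-wide row of B (the
// __outer_product_1x2 / __outer_product_1x4 calls). The function below is an
// assumed reference of that decomposition, not the library's implementation.
template <int M, int N, int K>
void gemm_transA_as_rank1_updates(const float* a, // K x M, A stored transposed
                                  const float* b, // K x N
                                  float* c)       // M x N, row-major
{
    for(int k = 0; k < K; ++k)
    {
        for(int m = 0; m < M; ++m)
        {
            const float a_km = a[k * M + m];
            // one "outer-product" step: scale row k of B and accumulate into row m of C
            for(int n = 0; n < N; ++n)
            {
                c[m * N + n] += a_km * b[k * N + n];
            }
        }
    }
}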
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
#define CK_THREADWISE_GENERIC_TENSOR_OP_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
|
||||
namespace ck {
|
||||
template <class Float, class TDesc>
|
||||
|
||||
@@ -6,14 +6,6 @@
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "tensor_coordinate.hpp"
|
||||
|
||||
#ifndef CK_USE_AMD_INTRINSIC
|
||||
#define CK_USE_AMD_INTRINSIC 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
// This version use multi-index transformation
|
||||
@@ -76,9 +68,12 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// Will do padding check on dst data: No write if dst data is in padding area.
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t SrcAddressSpace = address_space_t::generic,
|
||||
address_space_t DstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void Run(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
@@ -122,15 +117,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
|
||||
// Check src vector's padding situation, only check the first data in this src
// vector. It's user's responsibility to make sure all data in the src vector
// has
// the same padding situation
// has the same padding situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
__buffer_load<SrcData, SrcDataPerAccess>(
|
||||
p_src, src_coord.GetOffset(), 0);
|
||||
fwd(p_src), src_coord.GetOffset(), 0);
|
||||
#else
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_coord.GetOffset()]);
|
||||
@@ -163,15 +157,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
|
||||
// Check dst vector's padding situation, only check the first data in this dst
// vector. It's user's responsibility to make sure all data in the dst vector
// has
// the same padding situation
// has the same padding situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
__buffer_store<DstData, DstDataPerAccess>(
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]),
|
||||
p_dst,
|
||||
fwd(p_dst),
|
||||
dst_coord.GetOffset(),
|
||||
0);
|
||||
#else
|
||||
@@ -188,6 +181,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
});
|
||||
}
|
||||
|
||||
template <typename SrcData, typename DstData>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
{
|
||||
constexpr auto generic_address_space =
|
||||
integral_constant<AddressSpace, AddressSpace::generic>{};
|
||||
|
||||
Run(p_src, p_dst, generic_address_space, generic_address_space);
|
||||
}
|
||||
|
||||
// Modify Length to 1, if Mask is set to false
|
||||
// Used for isolating linear dimension from non-linear dimensions
|
||||
template <index_t... Lengths, index_t... Mask>
|
||||
@@ -202,12 +204,16 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// Will do padding check on src data: Read 0 if src data is in padding area.
|
||||
// Will do padding check on dst data: No write if dst data is in padding area.
|
||||
// This version is optimized for address calculation of src tensor
|
||||
// TODO: this function is not compiled to expected ISA
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t SrcAddressSpace = address_space_t::generic,
|
||||
address_space_t DstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run_optimized_src_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst) const
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void
|
||||
Run_optimized_src_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
@@ -287,14 +293,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
const auto src_coord =
|
||||
src_nonlinear_coord + (linear_dim_data_steps + scalar_id);
|
||||
|
||||
#if 1 // tweaking
|
||||
// this is src compile-time offset
|
||||
const index_t src_linear_offset =
|
||||
src_coord.GetOffset() - src_nonlinear_coord.GetOffset();
|
||||
#else
|
||||
#if CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF // tweaking
|
||||
// this is src compile-time offset
|
||||
const index_t src_linear_offset =
|
||||
src_nonlinear_coord.CalculateOffsetDiff(linear_dim_data_steps + scalar_id);
|
||||
#else
|
||||
// this is src compile-time offset
|
||||
const index_t src_linear_offset =
|
||||
src_coord.GetOffset() - src_nonlinear_coord.GetOffset();
|
||||
#endif
|
||||
|
||||
// Check src vector's padding situation, only check the first data in
|
||||
@@ -302,8 +308,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// the src vector has the same padding situation
|
||||
if(src_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
__buffer_load<SrcData, SrcDataPerAccess>(
|
||||
p_src, src_nonlinear_coord.GetOffset(), src_linear_offset);
|
||||
@@ -360,12 +366,16 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// Will do padding check on src data: Read 0 if src data is in padding area.
|
||||
// Will do padding check on dst data: No write if dst data is in padding area.
|
||||
// This version is optimized for address calculation of dst tensor
|
||||
// TODO: this function is not compiled to expected ISA
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t SrcAddressSpace = address_space_t::generic,
|
||||
address_space_t DstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run_optimized_dst_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst) const
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void
|
||||
Run_optimized_dst_address_calculation(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
@@ -476,14 +486,14 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
const auto dst_coord =
|
||||
dst_nonlinear_coord + (linear_dim_data_steps + scalar_id);
|
||||
|
||||
#if 1 // tweaking
|
||||
// this is dst compile-time offset
|
||||
const index_t dst_linear_offset =
|
||||
dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();
|
||||
#else
|
||||
#if CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF // tweaking
|
||||
// this is dst compile-time offset
|
||||
const index_t dst_linear_offset =
|
||||
dst_nonlinear_coord.CalculateOffsetDiff(linear_dim_data_steps + scalar_id);
|
||||
#else
|
||||
// this is dst compile-time offset
|
||||
const index_t dst_linear_offset =
|
||||
dst_coord.GetOffset() - dst_nonlinear_coord.GetOffset();
|
||||
#endif
|
||||
|
||||
// Check dst vector's padding situation, only check the first data in
|
||||
@@ -491,8 +501,8 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
// the dst vector has the same padding situation
|
||||
if(dst_coord.IsUpperIndexMappedToValidOffset())
|
||||
{
|
||||
static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
__buffer_store<DstData, DstDataPerAccess>(
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]),
|
||||
p_dst,
|
||||
@@ -514,6 +524,15 @@ struct ThreadwiseGenericTensorSliceCopy_v4r2
|
||||
});
|
||||
}
|
||||
|
||||
__device__ static constexpr bool HasWorkingOptimizedAddressCalculation()
|
||||
{
|
||||
#if CK_EXPERIMENTAL_THREADWISE_COPY_V4R2_USE_OPTIMIZED_ADDRESS_CACLULATION // tweaking
|
||||
return true;
|
||||
#else
|
||||
return false;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSliceWindow(const T& step_sizes_,
|
||||
integral_constant<bool, PositiveDirection>)
|
||||
|
||||
@@ -2,261 +2,12 @@
|
||||
#define CK_THREADWISE_GENERIC_TENSOR_SLICE_COPY_DEPRECATED_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
#include "ConstantMergedTensorDescriptor.hpp"
|
||||
#include "ConstantTensorDescriptor_deprecated.hpp"
|
||||
#include "ConstantMergedTensorDescriptor_deprecated.hpp"
|
||||
#include "tensor_coordinate_deprecated.hpp"
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_INTRINSIC
|
||||
#define CK_USE_AMD_INTRINSIC 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 1
|
||||
#endif
|
||||
|
||||
namespace ck {
|
||||
|
||||
// This threadwise copy allows vector access of src and dst.
// It allows the dimensions of vector access to be different on src and dst.
// It also allows the vector size to be different on src and dst.
// It also allows the order of access to be different on src and dst.
// It uses registers as a buffer to hold all data moving from src to dst.
// It is designed for copying a small amount of data, and src and dst are
// device memory or LDS.
// When copying a large amount of data, let's hope the compiler will reduce the
// registers used for the buffer.
|
||||
template <typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename SrcDimAccessOrder,
|
||||
typename DstDimAccessOrder,
|
||||
index_t SrcVectorAccessDim,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v1r1
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::GetSize();
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r1(
|
||||
Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
nDim == DstDesc::GetNumOfDimension() && nDim == SliceLengths::GetSize() &&
|
||||
nDim == SrcDimAccessOrder::GetSize() &&
|
||||
nDim == DstDimAccessOrder::GetSize(),
|
||||
"wrong! # of dimensions not the same");
|
||||
|
||||
static_assert(is_valid_sequence_map<SrcDimAccessOrder>::value &&
|
||||
is_valid_sequence_map<DstDimAccessOrder>::value,
|
||||
"wrong! map is not valid");
|
||||
|
||||
static_assert(SliceLengths{}[SrcVectorAccessDim] % SrcDataPerAccess == 0 &&
|
||||
SliceLengths{}[DstVectorAccessDim] % DstDataPerAccess == 0,
|
||||
"wrong! cannot evenly divide");
|
||||
|
||||
// check vectorized memory access
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
|
||||
static_if<!SrcDesc::ContainMultipleOriginalDimensions(src_vector_access_dim)>{}(
|
||||
[&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetStride(src_vector_access_dim) == 1 || SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(SrcDesc{}).GetLastOriginalDimensionStride(src_vector_access_dim) == 1 ||
|
||||
SrcDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
|
||||
static_if<!DstDesc::ContainMultipleOriginalDimensions(dst_vector_access_dim)>{}(
|
||||
[&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetStride(dst_vector_access_dim) == 1 || DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
})
|
||||
.Else([&](auto fwd) {
|
||||
static_assert(
|
||||
(fwd(DstDesc{}).GetLastOriginalDimensionStride(dst_vector_access_dim) == 1 ||
|
||||
DstDataPerAccess == 1),
|
||||
"wrong! vectorized access is allowed only if stride == 1");
|
||||
});
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r1()
|
||||
: ThreadwiseGenericTensorSliceCopy_v1r1(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
__device__ void SetSrcSliceOrigin(Array<index_t, nDim> src_slice_origin)
|
||||
{
|
||||
mSrcSliceOrigin = src_slice_origin;
|
||||
}
|
||||
|
||||
__device__ void SetDstSliceOrigin(Array<index_t, nDim> dst_slice_origin)
|
||||
{
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
TData p_buffer_[buffer_desc.GetElementSpace()];
|
||||
TData* p_buffer = p_buffer_;
|
||||
|
||||
// copy data from src into buffer
|
||||
{
|
||||
using vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto src_vector_access_dim = Number<SrcVectorAccessDim>{};
|
||||
constexpr auto src_data_per_access = Number<SrcDataPerAccess>{};
|
||||
|
||||
constexpr auto src_access_lengths = SliceLengths::Modify(
|
||||
src_vector_access_dim,
|
||||
SliceLengths::Get(src_vector_access_dim) / src_data_per_access);
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
|
||||
static_ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
|
||||
constexpr auto src_data_begin_id = src_access_id.Modify(
|
||||
src_vector_access_dim,
|
||||
src_access_id[src_vector_access_dim] * src_data_per_access);
|
||||
|
||||
const index_t src_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_begin_id);
|
||||
|
||||
// load vector from src
|
||||
const vector_t vector_data = *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
|
||||
|
||||
// unpack vector into buffer
|
||||
static_for<0, SrcDataPerAccess, 1>{}([&](auto i) {
|
||||
constexpr auto scalar_id =
|
||||
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(src_vector_access_dim,
|
||||
i);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);
|
||||
|
||||
p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
|
||||
});
|
||||
});
|
||||
#else
|
||||
ford<decltype(src_access_lengths), SrcDimAccessOrder>{}([&](auto src_access_id) {
|
||||
auto src_data_begin_id = src_access_id;
|
||||
src_data_begin_id(src_vector_access_dim) =
|
||||
src_access_id[src_vector_access_dim] * src_data_per_access;
|
||||
|
||||
const index_t src_offset =
|
||||
SrcDesc::GetOffsetFromMultiIndex(mSrcSliceOrigin + src_data_begin_id);
|
||||
|
||||
// load vector from src
|
||||
const vector_t vector_data = *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
|
||||
|
||||
// unpack vector into buffer
|
||||
for(index_t i = 0; i < SrcDataPerAccess; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(src_vector_access_dim) = i;
|
||||
|
||||
const index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(src_data_begin_id + scalar_id);
|
||||
|
||||
p_buffer[buffer_offset] = reinterpret_cast<const TData*>(&vector_data)[i];
|
||||
}
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
// copy data from buffer to dst
|
||||
{
|
||||
using vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto dst_vector_access_dim = Number<DstVectorAccessDim>{};
|
||||
constexpr auto dst_data_per_access = Number<DstDataPerAccess>{};
|
||||
|
||||
constexpr auto dst_access_lengths = SliceLengths::Modify(
|
||||
dst_vector_access_dim,
|
||||
SliceLengths::Get(dst_vector_access_dim) / dst_data_per_access);
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1
|
||||
static_ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
|
||||
constexpr auto dst_data_begin_id = dst_access_id.Modify(
|
||||
dst_vector_access_dim,
|
||||
dst_access_id[dst_vector_access_dim] * dst_data_per_access);
|
||||
|
||||
vector_t vector_data{};
|
||||
|
||||
// pack vector from buffer
|
||||
static_for<0, DstDataPerAccess, 1>{}([&](auto i) {
|
||||
constexpr auto scalar_id =
|
||||
typename uniform_sequence_gen<nDim, 0>::type{}.Modify(dst_vector_access_dim,
|
||||
i);
|
||||
|
||||
constexpr index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(dst_data_begin_id + scalar_id);
|
||||
|
||||
reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
|
||||
});
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_begin_id);
|
||||
|
||||
// store vector into dst
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = vector_data;
|
||||
});
|
||||
#else
|
||||
ford<decltype(dst_access_lengths), DstDimAccessOrder>{}([&](auto dst_access_id) {
|
||||
auto dst_data_begin_id = dst_access_id;
|
||||
dst_data_begin_id(dst_vector_access_dim) =
|
||||
dst_access_id[dst_vector_access_dim] * dst_data_per_access;
|
||||
|
||||
vector_t vector_data{};
|
||||
|
||||
// pack vector from buffer
|
||||
for(index_t i = 0; i < DstDataPerAccess; ++i)
|
||||
{
|
||||
auto scalar_id = make_zero_array<index_t, nDim>();
|
||||
scalar_id(dst_vector_access_dim) = i;
|
||||
|
||||
const index_t buffer_offset =
|
||||
buffer_desc.GetOffsetFromMultiIndex(dst_data_begin_id + scalar_id);
|
||||
|
||||
reinterpret_cast<TData*>(&vector_data)[i] = p_buffer[buffer_offset];
|
||||
}
|
||||
|
||||
const index_t dst_offset =
|
||||
DstDesc::GetOffsetFromMultiIndex(mDstSliceOrigin + dst_data_begin_id);
|
||||
|
||||
// store vector into dst
|
||||
*reinterpret_cast<vector_t*>(&p_dst[dst_offset]) = vector_data;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
    private:
    Array<index_t, nDim> mSrcSliceOrigin;
    Array<index_t, nDim> mDstSliceOrigin;
};
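// Sketch (illustrative, not part of this library): the load-vector-then-unpack
// idiom used by Run() above, reduced to plain C++. A DataPerAccess-wide chunk
// is read as one vector and then scattered into a scalar register buffer;
// the store side does the reverse. `vec4` and `unpack_vector4` are placeholder
// names, and the 16-byte alignment of p_src is an assumption of the sketch.
#include <cstring>

struct vec4 { float x, y, z, w; };

inline void unpack_vector4(const float* p_src, float* p_buffer)
{
    // one vectorized read (memcpy keeps this well-defined even without the
    // reinterpret_cast the device code uses)
    vec4 v;
    std::memcpy(&v, p_src, sizeof(vec4));

    // unpack the vector into the scalar buffer, one element per slot
    p_buffer[0] = v.x;
    p_buffer[1] = v.y;
    p_buffer[2] = v.z;
    p_buffer[3] = v.w;
}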
|
||||
|
||||
// This threadwise copy allows vector access of src and dst.
|
||||
// It allows the vector size to be different on src and dst.
|
||||
// The dimensions of vector access should be the same on src and dst.
|
||||
@@ -270,11 +21,11 @@ template <typename SrcDesc,
|
||||
index_t VectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
struct ThreadwiseGenericTensorSliceCopy_v1r2_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::GetSize();
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2(
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(
|
||||
Array<index_t, nDim> src_slice_origin, Array<index_t, nDim> dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
@@ -313,9 +64,9 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
});
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2()
|
||||
: ThreadwiseGenericTensorSliceCopy_v1r2(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v1r2_deprecated()
|
||||
: ThreadwiseGenericTensorSliceCopy_v1r2_deprecated(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
@@ -329,11 +80,11 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
mDstSliceOrigin = dst_slice_origin;
|
||||
}
|
||||
|
||||
template <typename TData>
|
||||
__device__ void Run(const TData* p_src, TData* p_dst) const
|
||||
template <class SrcData, class DstData>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
{
|
||||
using src_vector_t = typename vector_type<TData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<TData, DstDataPerAccess>::MemoryType;
|
||||
using src_vector_t = typename vector_type<SrcData, SrcDataPerAccess>::MemoryType;
|
||||
using dst_vector_t = typename vector_type<DstData, DstDataPerAccess>::MemoryType;
|
||||
|
||||
constexpr auto vector_access_dim = Number<VectorAccessDim>{};
|
||||
|
||||
@@ -345,46 +96,6 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
constexpr auto long_vector_access_lengths = SliceLengths::Modify(
|
||||
vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
|
||||
|
||||
#if CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2
|
||||
static_ford<decltype(long_vector_access_lengths), DimAccessOrder>{}([&](
|
||||
auto long_vector_access_id) {
|
||||
|
||||
// data id w.r.t slicing-window
|
||||
constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
|
||||
vector_access_dim, long_vector_access_id[vector_access_dim] * long_vector_size);
|
||||
|
||||
// buffer to hold a long-vector
|
||||
TData p_long_vector[long_vector_size];
|
||||
|
||||
// load data from src to the long-vector buffer
|
||||
static_for<0, long_vector_size / src_data_per_access, 1>{}([&](auto i) {
|
||||
constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
|
||||
vector_access_dim, i * src_data_per_access);
|
||||
|
||||
const index_t src_offset = SrcDesc::GetOffsetFromMultiIndex(
|
||||
mSrcSliceOrigin + (long_vector_data_begin_id + scalar_id));
|
||||
|
||||
constexpr index_t buffer_offset = i * src_data_per_access;
|
||||
|
||||
*reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
|
||||
});
|
||||
|
||||
// store data from the long-vector buffer to dst
|
||||
static_for<0, long_vector_size / dst_data_per_access, 1>{}([&](auto i) {
|
||||
constexpr auto scalar_id = typename uniform_sequence_gen<nDim, 0>::type{}.Modify(
|
||||
vector_access_dim, i * dst_data_per_access);
|
||||
|
||||
constexpr index_t buffer_offset = i * dst_data_per_access;
|
||||
|
||||
const index_t dst_offset = DstDesc::GetOffsetFromMultiIndex(
|
||||
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));
|
||||
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
|
||||
});
|
||||
});
|
||||
#else
|
||||
ford<decltype(long_vector_access_lengths), DimAccessOrder>{}(
|
||||
[&](auto long_vector_access_id) {
|
||||
|
||||
@@ -394,7 +105,8 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
long_vector_size * long_vector_access_id[vector_access_dim];
|
||||
|
||||
// buffer to hold a long-vector
|
||||
TData p_long_vector[long_vector_size];
|
||||
SrcData p_src_long_vector[long_vector_size];
|
||||
DstData p_dst_long_vector[long_vector_size];
|
||||
|
||||
// load data from src to the long-vector buffer
|
||||
for(index_t i = 0; i < long_vector_size / src_data_per_access; ++i)
|
||||
@@ -407,10 +119,16 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
|
||||
const index_t buffer_offset = i * src_data_per_access;
|
||||
|
||||
*reinterpret_cast<src_vector_t*>(&p_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<src_vector_t*>(&p_src_long_vector[buffer_offset]) =
|
||||
*reinterpret_cast<const src_vector_t*>(&p_src[src_offset]);
|
||||
}
|
||||
|
||||
// type conversion
|
||||
for(index_t i = 0; i < long_vector_size; ++i)
|
||||
{
|
||||
p_dst_long_vector[i] = type_convert<DstData>{}(p_src_long_vector[i]);
|
||||
}
|
||||
|
||||
// store data from the long-vector buffer to dst
|
||||
for(index_t i = 0; i < long_vector_size / dst_data_per_access; ++i)
|
||||
{
|
||||
@@ -423,10 +141,9 @@ struct ThreadwiseGenericTensorSliceCopy_v1r2
|
||||
mDstSliceOrigin + (long_vector_data_begin_id + scalar_id));
|
||||
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst[dst_offset]) =
|
||||
*reinterpret_cast<dst_vector_t*>(&p_long_vector[buffer_offset]);
|
||||
*reinterpret_cast<dst_vector_t*>(&p_dst_long_vector[buffer_offset]);
|
||||
}
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
private:
|
||||
@@ -453,7 +170,7 @@ template <typename SrcDesc,
|
||||
index_t DstVectorAccessDim,
|
||||
index_t SrcDataPerAccess,
|
||||
index_t DstDataPerAccess>
|
||||
struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
struct ThreadwiseGenericTensorSliceCopy_v2r1_deprecated
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::GetSize();
|
||||
|
||||
@@ -462,8 +179,8 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
using SrcCoordinate = typename TensorCoordinate_deprecated<SrcDesc>::type;
|
||||
using DstCoordinate = typename TensorCoordinate_deprecated<DstDesc>::type;
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1(const Index& src_slice_origin,
|
||||
const Index& dst_slice_origin)
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(
|
||||
const Index& src_slice_origin, const Index& dst_slice_origin)
|
||||
: mSrcSliceOrigin(src_slice_origin), mDstSliceOrigin(dst_slice_origin)
|
||||
{
|
||||
static_assert(nDim == SrcDesc::GetNumOfDimension() &&
|
||||
@@ -511,9 +228,9 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
});
|
||||
}
|
||||
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1()
|
||||
: ThreadwiseGenericTensorSliceCopy_v2r1(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
__device__ constexpr ThreadwiseGenericTensorSliceCopy_v2r1_deprecated()
|
||||
: ThreadwiseGenericTensorSliceCopy_v2r1_deprecated(make_zero_array<index_t, nDim>(),
|
||||
make_zero_array<index_t, nDim>())
|
||||
{
|
||||
}
|
||||
|
||||
@@ -539,9 +256,12 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
address_space_t SrcAddressSpace = address_space_t::generic,
|
||||
address_space_t DstAddressSpace = address_space_t::generic>
|
||||
__device__ void Run(const SrcData* p_src, DstData* p_dst) const
|
||||
AddressSpace SrcAddressSpace,
|
||||
AddressSpace DstAddressSpace>
|
||||
__device__ void Run(const SrcData* p_src,
|
||||
DstData* p_dst,
|
||||
integral_constant<AddressSpace, SrcAddressSpace>,
|
||||
integral_constant<AddressSpace, DstAddressSpace>) const
|
||||
{
|
||||
constexpr auto buffer_desc = make_ConstantTensorDescriptor_packed(SliceLengths{});
|
||||
|
||||
@@ -613,10 +333,10 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
// 2. src_normal_offset must be calculatd at compile time (guaranteed by
|
||||
// algorithm)
|
||||
// 3. src_merged_offset can be runtime value (no assumption imposed)
|
||||
static_if<SrcAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
static_if<SrcAddressSpace == AddressSpace::global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
vector_data = __buffer_load<SrcData, SrcDataPerAccess>(
|
||||
p_src, src_merged_offset, src_normal_offset);
|
||||
fwd(p_src), src_merged_offset, src_normal_offset);
|
||||
#else
|
||||
vector_data = *reinterpret_cast<const src_vector_t*>(
|
||||
&p_src[src_normal_offset + src_merged_offset]);
|
||||
@@ -722,10 +442,10 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
// 2. dst_normal_offset must be calculatd at compile time (guaranteed by
|
||||
// algorithm)
|
||||
// 3. dst_merged_offset can be runtime value (no assumption imposed)
|
||||
static_if<DstAddressSpace == address_space_t::global>{}([&](auto) {
|
||||
#if CK_USE_AMD_INTRINSIC && CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE
|
||||
static_if<DstAddressSpace == AddressSpace::global>{}([&](auto fwd) {
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING
|
||||
__buffer_store<SrcData, DstDataPerAccess>(
|
||||
vector_data, p_dst, dst_merged_offset, dst_normal_offset);
|
||||
vector_data, fwd(p_dst), dst_merged_offset, dst_normal_offset);
|
||||
#else
|
||||
*reinterpret_cast<dst_vector_t*>(
|
||||
&p_dst[dst_normal_offset + dst_merged_offset]) = vector_data;
|
||||
@@ -740,6 +460,15 @@ struct ThreadwiseGenericTensorSliceCopy_v2r1
|
||||
}
|
||||
}
|
||||
|
||||
    template <typename SrcData, typename DstData>
    __device__ void Run(const SrcData* p_src, DstData* p_dst) const
    {
        constexpr auto generic_address_space =
            integral_constant<AddressSpace, AddressSpace::generic>{};

        Run(p_src, p_dst, generic_address_space, generic_address_space);
    }
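The two-argument `Run` above is only a convenience wrapper: it forwards to the tag-taking overload with `AddressSpace::generic` for both sides, while callers that know their pointers live in global memory pass `integral_constant<AddressSpace, AddressSpace::global>{}` to opt into buffer addressing. A standalone sketch of the tag-dispatch idiom, using `std::integral_constant` rather than the library's own helpers:

#include <cstdio>
#include <type_traits>

enum class AddressSpace { generic, global };

template <AddressSpace AS>
using address_space_tag = std::integral_constant<AddressSpace, AS>;

// Tag-dispatched copy: overload resolution happens at compile time, so the
// "global" path can use a different instruction sequence.
void copy_impl(const float* src, float* dst, address_space_tag<AddressSpace::generic>)
{
    *dst = *src; // plain load/store
    std::puts("generic path");
}

void copy_impl(const float* src, float* dst, address_space_tag<AddressSpace::global>)
{
    *dst = *src; // placeholder for a buffer_load/buffer_store path
    std::puts("global (buffer addressing) path");
}

// Convenience wrapper mirroring the two-argument Run(): defaults to generic.
void copy(const float* src, float* dst)
{
    copy_impl(src, dst, address_space_tag<AddressSpace::generic>{});
}

int main()
{
    float a = 1.0f, b = 0.0f;
    copy(&a, &b);                                                 // generic
    copy_impl(&a, &b, address_space_tag<AddressSpace::global>{}); // explicit global
}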
|
||||
// T can be Sequence or Array
|
||||
template <typename T, bool PositiveDirection>
|
||||
__device__ void MoveSrcSliceWindow(T step_sizes, integral_constant<bool, PositiveDirection>)
|
||||
composable_kernel/include/utility/amd_buffer_addressing.hpp (new file, 284 lines)
@@ -0,0 +1,284 @@
#ifndef CK_AMD_BUFFER_ADDRESSING_HPP
|
||||
#define CK_AMD_BUFFER_ADDRESSING_HPP
|
||||
|
||||
#include "float_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// For 128bit SGPRs in buffer_load and buffer_store instructions
|
||||
// https://rocm-documentation.readthedocs.io/en/latest/GCN_ISA_Manuals/testdocbook.html#vector-memory-buffer-instructions
|
||||
template <typename T>
|
||||
union BufferLoadStoreDwordConfig
|
||||
{
|
||||
int32x4_t data;
|
||||
T* address[2];
|
||||
int32_t range[4];
|
||||
};
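The union overlays three views of the same 128-bit buffer resource: the `int32x4_t` handed to the intrinsics, a pointer view used to write the base address into dwords 0-1, and an `int32_t` view used to set the range (dword 2) and the configuration word `0x00027000` (dword 3). A hedged sketch of assembling a descriptor from a block pointer, mirroring the specializations below (the meaning of the configuration word is hardware-specific and not spelled out here):

// Sketch only: build a buffer resource descriptor for a float block,
// using the union and vector type defined in this header.
__device__ int32x4_t make_buffer_resource(const float* p_block)
{
    BufferLoadStoreDwordConfig<float> config;

    config.address[0] = const_cast<float*>(p_block); // dwords 0-1: base address
    config.range[2]   = -1;                          // dword 2: buffer range
    config.range[3]   = 0x00027000;                  // dword 3: resource configuration word
    return config.data;                              // 128-bit descriptor as int32x4_t
}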
__device__ float __llvm_amdgcn_buffer_load(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load");
|
||||
|
||||
__device__ float2_t __llvm_amdgcn_buffer_loadx2(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.dwordx2");
|
||||
|
||||
__device__ float4_t __llvm_amdgcn_buffer_loadx4(int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.dwordx4");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store(float vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_storex2(float2_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.dwordx2");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_storex4(float4_t vdata,
|
||||
int32x4_t rsrc,
|
||||
index_t vindex,
|
||||
index_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.dwordx4");
|
||||
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ typename vector_type<T, VectorSize>::MemoryType
|
||||
__buffer_load(const T* p_src_block, index_t src_thread_data_offset, index_t src_const_data_offset);
|
||||
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ void __buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src,
|
||||
T* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset);
|
||||
|
||||
template <>
__device__ float __buffer_load<float, 1>(const float* p_src_block,
                                          index_t src_thread_data_offset,
                                          index_t src_const_data_offset)
{
    float dst;

    index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
    index_t src_const_addr_offset  = src_const_data_offset * sizeof(float);

    BufferLoadStoreDwordConfig<float> src_block_config;

    // fill in dwords 0-1: base address of the source block
    src_block_config.address[0] = const_cast<float*>(p_src_block);
    // fill in dword 2: buffer range (-1)
    src_block_config.range[2] = -1;
    // fill in dword 3: buffer resource configuration word
    src_block_config.range[3] = 0x00027000;

#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
    dst = __llvm_amdgcn_buffer_load(
        src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
#else
    asm volatile(
        "\n \
        buffer_load_dword %0, %1, %2, %3 offen offset:0 \n \
        s_waitcnt 0 \n \
        "
        : "=v"(dst)
        : "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
#endif

    return dst;
}
|
||||
|
||||
template <>
|
||||
__device__ float2_t __buffer_load<float, 2>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
float2_t dst;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
dst = __llvm_amdgcn_buffer_loadx2(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
asm volatile(
|
||||
"\n \
|
||||
buffer_load_dwordx2 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
|
||||
#endif
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ float4_t __buffer_load<float, 4>(const float* p_src_block,
|
||||
index_t src_thread_data_offset,
|
||||
index_t src_const_data_offset)
|
||||
{
|
||||
float4_t dst;
|
||||
|
||||
index_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
index_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> src_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
src_block_config.address[0] = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
src_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
src_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
dst = __llvm_amdgcn_buffer_loadx4(
|
||||
src_block_config.data, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
#else
|
||||
asm volatile(
|
||||
"\n \
|
||||
buffer_load_dwordx4 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_config.data), "s"(src_const_addr_offset));
|
||||
#endif
|
||||
|
||||
return dst;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void __buffer_store<float, 1>(const float& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_store(src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
asm volatile("\n \
|
||||
buffer_store_dword %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_config.data),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void __buffer_store<float, 2>(const float2_t& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_storex2(src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
asm volatile("\n \
|
||||
buffer_store_dwordx2 %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_config.data),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void __buffer_store<float, 4>(const float4_t& src,
|
||||
float* p_dst_block,
|
||||
index_t dst_thread_data_offset,
|
||||
index_t dst_const_data_offset)
|
||||
{
|
||||
index_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
index_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
BufferLoadStoreDwordConfig<float> dst_block_config;
|
||||
|
||||
// fill in byte 0 - 1
|
||||
dst_block_config.address[0] = p_dst_block;
|
||||
// fill in byte 2
|
||||
dst_block_config.range[2] = -1;
|
||||
// fill in byte 3
|
||||
dst_block_config.range[3] = 0x00027000;
|
||||
|
||||
#if CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
__llvm_amdgcn_buffer_storex4(src,
|
||||
dst_block_config.data,
|
||||
0,
|
||||
dst_thread_addr_offset + dst_const_addr_offset,
|
||||
false,
|
||||
false);
|
||||
#else
|
||||
asm volatile("\n \
|
||||
buffer_store_dwordx4 %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_config.data),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
#endif
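Taken together, the specializations above give a uniform `__buffer_load<T, VectorSize>` / `__buffer_store<T, VectorSize>` pair whose offsets are expressed in elements of `T`, not bytes. A hedged usage sketch; the kernel, launch geometry, and offsets are illustrative only:

// Sketch: each thread moves one float4 from a global source block to a global
// destination block through the buffer-addressing helpers declared above.
__global__ void copy_float4_kernel(const float* p_src_block, float* p_dst_block)
{
    using namespace ck;

    // per-thread offset in elements (4 floats per thread)
    const index_t thread_data_offset = 4 * threadIdx.x;

    // load 4 floats starting at thread_data_offset (no extra constant offset)
    float4_t v = __buffer_load<float, 4>(p_src_block, thread_data_offset, 0);

    // store them to the destination block at the same element offset
    __buffer_store<float, 4>(v, p_dst_block, thread_data_offset, 0);
}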
@@ -1,84 +1,31 @@
|
||||
#ifndef CK_AMD_INLINE_ASM_HPP
|
||||
#define CK_AMD_INLINE_ASM_HPP
|
||||
|
||||
#include "vector_type.hpp"
|
||||
#include "float_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
// cast a pointer of LDS to its address
|
||||
extern "C" __attribute__((address_space(3))) __device__ void* __to_local(void* p);
|
||||
|
||||
__device__ void vmcnt(index_t cnt)
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
|
||||
{
|
||||
if(cnt == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(0) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(1) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 2)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 4)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt vmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
// disable inline asm due to the compiler issue: SWDEV-202749
|
||||
///\to-do: enable the inline asm after the compiler fix
|
||||
#if CK_WORKAROUND_SWDEV_202749
|
||||
c0 += a * b0;
|
||||
c1 += a * b1;
|
||||
#else
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %2, %3 \n \
|
||||
v_mac_f32 %1, %2, %4 \n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1)
|
||||
: "v"(a), "v"(b0), "v"(b1), "0"(c0), "1"(c1));
|
||||
#endif
|
||||
}
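`__outer_product_1x2` accumulates one element of `a` against two elements of `b` into two running sums, either through `v_mac_f32` inline asm or, when `CK_WORKAROUND_SWDEV_202749` is set, through plain C++ so the compiler issue is sidestepped. A small sketch of how a register-tile GEMM might use it to accumulate a 2x2 tile; the loop structure and names are illustrative, not the library's blockwise GEMM:

// Sketch: accumulate C[2][2] += A[2][k] * B[k][2] for one k-slice,
// using the 1x2 outer-product primitive once per row of the tile.
__device__ void gemm_2x2_tile(const float* a_col, // a_col[i] = A[i][k]
                              const float* b_row, // b_row[j] = B[k][j]
                              float c[2][2])
{
    // row 0 of the tile: c[0][0] += a0*b0, c[0][1] += a0*b1
    __outer_product_1x2(a_col[0], b_row[0], b_row[1], c[0][0], c[0][1]);
    // row 1 of the tile: c[1][0] += a1*b0, c[1][1] += a1*b1
    __outer_product_1x2(a_col[1], b_row[0], b_row[1], c[1][0], c[1][1]);
}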
__device__ void lgkmcnt(index_t cnt)
|
||||
{
|
||||
if(cnt == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(0) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 1)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(1) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 2)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(2) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(3) \n \
|
||||
" ::);
|
||||
}
|
||||
else if(cnt == 4)
|
||||
{
|
||||
asm volatile("\n \
|
||||
s_waitcnt lgkmcnt(4) \n \
|
||||
" ::);
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void outerProduct1x4(const float* a, const float* b, float* c)
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x4(
|
||||
float a, float b0, float b1, float b2, float b3, float& c0, float& c1, float& c2, float& c3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_mac_f32 %0, %4, %5 \n \
|
||||
@@ -86,596 +33,122 @@ __device__ void outerProduct1x4(const float* a, const float* b, float* c)
|
||||
v_mac_f32 %2, %4, %7 \n \
|
||||
v_mac_f32 %3, %4, %8 \n \
|
||||
"
|
||||
: "=v"(c[0]), "=v"(c[1]), "=v"(c[2]), "=v"(c[3])
|
||||
: "v"(a[0]),
|
||||
"v"(b[0]),
|
||||
"v"(b[1]),
|
||||
"v"(b[2]),
|
||||
"v"(b[3]),
|
||||
"0"(c[0]),
|
||||
"1"(c[1]),
|
||||
"2"(c[2]),
|
||||
"3"(c[3]));
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3)
|
||||
: "v"(a), "v"(b0), "v"(b1), "v"(b2), "v"(b3), "0"(c0), "1"(c1), "2"(c2), "3"(c3));
|
||||
}
|
||||
|
||||
__device__ void outerProduct1x4(const float& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c)
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x2(half2_t a, half2_t b0, half2_t b1, float& c0, float& c1)
|
||||
{
|
||||
outerProduct1x4(&a, reinterpret_cast<const float*>(&b), reinterpret_cast<float*>(&c));
|
||||
}
|
||||
|
||||
__device__ void outerProduct2x4(const vector_type<float, 2>::MemoryType& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c0,
|
||||
vector_type<float, 4>::MemoryType& c1)
|
||||
{
|
||||
outerProduct1x4(a.x, b, c0);
|
||||
outerProduct1x4(a.y, b, c1);
|
||||
}
|
||||
|
||||
__device__ void outerProduct4x4(const vector_type<float, 4>::MemoryType& a,
|
||||
const vector_type<float, 4>::MemoryType& b,
|
||||
vector_type<float, 4>::MemoryType& c0,
|
||||
vector_type<float, 4>::MemoryType& c1,
|
||||
vector_type<float, 4>::MemoryType& c2,
|
||||
vector_type<float, 4>::MemoryType& c3)
|
||||
{
|
||||
outerProduct1x4(a.x, b, c0);
|
||||
outerProduct1x4(a.y, b, c1);
|
||||
outerProduct1x4(a.z, b, c2);
|
||||
outerProduct1x4(a.w, b, c3);
|
||||
}
|
||||
|
||||
__device__ void outerProduct8x8(const vector_type<float, 4>::MemoryType* a,
|
||||
const vector_type<float, 4>::MemoryType* b,
|
||||
vector_type<float, 4>::MemoryType* c)
|
||||
{
|
||||
outerProduct4x4(a[0], b[0], c[0], c[2], c[4], c[6]);
|
||||
outerProduct4x4(a[0], b[1], c[1], c[3], c[5], c[7]);
|
||||
outerProduct4x4(a[1], b[0], c[8], c[10], c[12], c[14]);
|
||||
outerProduct4x4(a[1], b[1], c[9], c[11], c[13], c[15]);
|
||||
}
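// Note on the deprecated helper below: ds_read_b128 dispatches over a fixed set of
// literal offsets (0, 64, 128, ..., 4096) because the offset field of the
// ds_read_b128 instruction is an immediate, so each supported value needs its own
// asm string; offsets outside this set fall through without performing a read
// (presumably never hit by the callers).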
__device__ void ds_read_b128(vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
|
||||
{
|
||||
if(offset == 0)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:0\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 64)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:64\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 128)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:128\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 192)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:192\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 256)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:256\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 320)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:320\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 384)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:384\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 448)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:448\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 512)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:512\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 576)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:576\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 640)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:640\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 704)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:704\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 768)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:768\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 832)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:832\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 896)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:896\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 960)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:960\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1024)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1024\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1088)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1088\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1152)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1152\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1216)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1216\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1280)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1280\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1344)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1344\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1408)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1408\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1472)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1472\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1536)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1536\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1600)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1600\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1664)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1664\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1728)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1728\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1792)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1792\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1856)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1856\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1920)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1920\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 1984)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:1984\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2048)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2048\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2112)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2112\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2176)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2176\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2240)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2240\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2304)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2304\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2368)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2368\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2432)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2432\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2496)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2496\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2560)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2560\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2624)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2624\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2688)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2688\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2752)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2752\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2816)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2816\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2880)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2880\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 2944)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:2944\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3008)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3008\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3072)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3072\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3136)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3136\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3200)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3200\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3264)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3264\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3328)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3328\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3392)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3392\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3456)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3456\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3520)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3520\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3584)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3584\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3648)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3648\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3712)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3712\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3776)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3776\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3840)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3840\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3904)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3904\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 3968)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:3968\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4032)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4032\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
if(offset == 4096)
|
||||
{
|
||||
asm volatile("\n \
|
||||
ds_read_b128 %0, %1 offset:4096\n \
|
||||
"
|
||||
: "=v"(r)
|
||||
: "v"(__to_local(lds)));
|
||||
}
|
||||
}
|
||||
|
||||
__device__ void
ds_write_b128(const vector_type<float, 4>::MemoryType& r, void* lds, index_t offset = 0)
{
    if(offset == 0)
    {
        asm volatile("\n \
            ds_write_b128 %0, %1 \n \
            "
            :
            : "v"(__to_local(lds)), "v"(r));
    }
    else
    {
        assert(false);
    }
}

    asm volatile("\n \
        v_dot2_f32_f16 %0, %2, %3 %0\n \
        v_dot2_f32_f16 %1, %2, %4 %1\n \
        "
        : "=v"(c0), "=v"(c1) // Dest registers
        : "v"(a),            // 1st Src register for 1 half2 registers
          "v"(b0),           // 2nd Src register
          "v"(b1),
          "0"(c0),           // 3rd Src register
          "1"(c1));
}
|
||||
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x2(half4_t a, half4_t b0, half4_t b1, float& c0, float& c1)
|
||||
{
|
||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
||||
|
||||
// do dot2 two times
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %2, %4 %0\n \
|
||||
v_dot2_f32_f16 %1, %2, %6 %1\n \
|
||||
v_dot2_f32_f16 %0, %3, %5 %0\n \
|
||||
v_dot2_f32_f16 %1, %3, %7 %1\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1) // Dest registers
|
||||
: "v"(p_a_half2[0]),
|
||||
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
|
||||
"v"(p_b0_half2[0]),
|
||||
"v"(p_b0_half2[1]),
|
||||
"v"(p_b1_half2[0]),
|
||||
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
|
||||
"0"(c0),
|
||||
"1"(c1)); // 3rd Src Acc registers for 2 half2 registers
|
||||
}
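The half-precision variants above lean on `v_dot2_f32_f16`, which accumulates a two-element fp16 dot product into an fp32 register, so a `half4_t` operand is consumed as two `half2_t` halves and each accumulator receives two dot2 instructions. A plain C++ scalar reference of what one such outer-product call computes, useful as a check against the asm (inputs are shown already promoted to float for simplicity):

#include <cstddef>

// Scalar reference for the half4_t 1x2 outer product:
// c0 += dot(a, b0), c1 += dot(a, b1), each dot over 4 elements,
// accumulated in fp32 (what the two pairs of v_dot2_f32_f16 compute).
void outer_product_1x2_reference(const float a[4],
                                 const float b0[4],
                                 const float b1[4],
                                 float& c0,
                                 float& c1)
{
    for(std::size_t v = 0; v < 4; ++v)
    {
        c0 += a[v] * b0[v];
        c1 += a[v] * b1[v];
    }
}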
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x4(half2_t a,
|
||||
half2_t b0,
|
||||
half2_t b1,
|
||||
half2_t b2,
|
||||
half2_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %4, %5 %0\n \
|
||||
v_dot2_f32_f16 %1, %4, %6 %1\n \
|
||||
v_dot2_f32_f16 %2, %4, %7 %2\n \
|
||||
v_dot2_f32_f16 %3, %4, %8 %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
|
||||
: "v"(a), // 1st Src register for 1 half2 registers
|
||||
"v"(b0), // 2nd Src register
|
||||
"v"(b1),
|
||||
"v"(b2),
|
||||
"v"(b3),
|
||||
"0"(c0), // 3rd Src register
|
||||
"1"(c1),
|
||||
"2"(c2),
|
||||
"3"(c3));
|
||||
}
|
||||
|
||||
// outer-product: c[i,j] += inner_product(a[i], b[j])
|
||||
__device__ void __outer_product_1x4(half4_t a,
|
||||
half4_t b0,
|
||||
half4_t b1,
|
||||
half4_t b2,
|
||||
half4_t b3,
|
||||
float& c0,
|
||||
float& c1,
|
||||
float& c2,
|
||||
float& c3)
|
||||
{
|
||||
const half2_t* p_a_half2 = reinterpret_cast<const half2_t*>(&a);
|
||||
const half2_t* p_b0_half2 = reinterpret_cast<const half2_t*>(&b0);
|
||||
const half2_t* p_b1_half2 = reinterpret_cast<const half2_t*>(&b1);
|
||||
const half2_t* p_b2_half2 = reinterpret_cast<const half2_t*>(&b2);
|
||||
const half2_t* p_b3_half2 = reinterpret_cast<const half2_t*>(&b3);
|
||||
|
||||
// do dot2 two times
|
||||
asm volatile("\n \
|
||||
v_dot2_f32_f16 %0, %4, %6 %0\n \
|
||||
v_dot2_f32_f16 %1, %4, %8 %1\n \
|
||||
v_dot2_f32_f16 %2, %4, %10 %2\n \
|
||||
v_dot2_f32_f16 %3, %4, %12 %3\n \
|
||||
v_dot2_f32_f16 %0, %5, %7 %0\n \
|
||||
v_dot2_f32_f16 %1, %5, %9 %1\n \
|
||||
v_dot2_f32_f16 %2, %5, %11 %2\n \
|
||||
v_dot2_f32_f16 %3, %5, %13 %3\n \
|
||||
"
|
||||
: "=v"(c0), "=v"(c1), "=v"(c2), "=v"(c3) // Dest registers
|
||||
: "v"(p_a_half2[0]),
|
||||
"v"(p_a_half2[1]), // 1st Src registers for 2 half2 registers
|
||||
"v"(p_b0_half2[0]),
|
||||
"v"(p_b0_half2[1]),
|
||||
"v"(p_b1_half2[0]),
|
||||
"v"(p_b1_half2[1]), // 2nd Src registers for 2 half2 registers
|
||||
"v"(p_b2_half2[0]),
|
||||
"v"(p_b2_half2[1]),
|
||||
"v"(p_b3_half2[0]),
|
||||
"v"(p_b3_half2[1]), // 2nd Src registers for 2 half2 registers
|
||||
"0"(c0),
|
||||
"1"(c1),
|
||||
"2"(c2),
|
||||
"3"(c3)); // 3rd Src Acc registers for 2 half2 registers
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
|
||||
@@ -1,332 +0,0 @@
|
||||
#ifndef CK_AMD_INTRINSIC_HPP
|
||||
#define CK_AMD_INTRINSIC_HPP
|
||||
|
||||
#include "vector_type.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
__device__ float __llvm_amdgcn_buffer_load(int32x4_t rsrc,
|
||||
uint32_t vindex,
|
||||
uint32_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load");
|
||||
|
||||
__device__ vector_type<float, 2>::MemoryType
|
||||
__llvm_amdgcn_buffer_loadx2(int32x4_t rsrc,
|
||||
uint32_t vindex,
|
||||
uint32_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.dwordx2");
|
||||
|
||||
__device__ vector_type<float, 4>::MemoryType
|
||||
__llvm_amdgcn_buffer_loadx4(int32x4_t rsrc,
|
||||
uint32_t vindex,
|
||||
uint32_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.load.dwordx4");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_store(float vdata,
|
||||
int32x4_t rsrc,
|
||||
uint32_t vindex,
|
||||
uint32_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_storex2(vector_type<float, 2>::MemoryType vdata,
|
||||
int32x4_t rsrc,
|
||||
uint32_t vindex,
|
||||
uint32_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.dwordx2");
|
||||
|
||||
__device__ void __llvm_amdgcn_buffer_storex4(vector_type<float, 4>::MemoryType vdata,
|
||||
int32x4_t rsrc,
|
||||
uint32_t vindex,
|
||||
uint32_t offset,
|
||||
bool glc,
|
||||
bool slc) __asm("llvm.amdgcn.buffer.store.dwordx4");
|
||||
|
||||
// buffer_load and buffer_store
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ typename vector_type<T, VectorSize>::MemoryType __buffer_load(
|
||||
const T* p_src_block, uint32_t src_thread_data_offset, uint32_t src_const_data_offset);
|
||||
|
||||
template <typename T, index_t VectorSize>
|
||||
__device__ void __buffer_store(const typename vector_type<T, VectorSize>::MemoryType& src,
|
||||
T* p_dst_block,
|
||||
uint32_t dst_thread_data_offset,
|
||||
uint32_t dst_const_data_offset);
|
||||
|
||||
template <>
|
||||
__device__ float __buffer_load<float, 1>(const float* p_src_block,
|
||||
uint32_t src_thread_data_offset,
|
||||
uint32_t src_const_data_offset)
|
||||
{
|
||||
#if 0
|
||||
float dst;
|
||||
|
||||
uint32_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
uint32_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_load_dword %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_setting), "s"(src_const_addr_offset));
|
||||
|
||||
return dst;
|
||||
#else
|
||||
float dst;
|
||||
|
||||
uint32_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
uint32_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
dst = __llvm_amdgcn_buffer_load(
|
||||
src_block_setting, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return dst;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ vector_type<float, 2>::MemoryType __buffer_load<float, 2>(
|
||||
const float* p_src_block, uint32_t src_thread_data_offset, uint32_t src_const_data_offset)
|
||||
{
|
||||
#if 0
|
||||
vector_type<float, 2>::MemoryType dst;
|
||||
|
||||
uint32_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
uint32_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_load_dwordx2 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_setting), "s"(src_const_addr_offset));
|
||||
|
||||
return dst;
|
||||
#else
|
||||
vector_type<float, 2>::MemoryType dst;
|
||||
|
||||
uint32_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
uint32_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
dst = __llvm_amdgcn_buffer_loadx2(
|
||||
src_block_setting, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return dst;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ vector_type<float, 4>::MemoryType __buffer_load<float, 4>(
|
||||
const float* p_src_block, uint32_t src_thread_data_offset, uint32_t src_const_data_offset)
|
||||
{
|
||||
#if 0
|
||||
vector_type<float, 4>::MemoryType dst;
|
||||
|
||||
uint32_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
uint32_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_load_dwordx4 %0, %1, %2, %3 offen offset:0 \n \
|
||||
s_waitcnt 0 \n \
|
||||
"
|
||||
: "=v"(dst)
|
||||
: "v"(src_thread_addr_offset), "s"(src_block_setting), "s"(src_const_addr_offset));
|
||||
|
||||
return dst;
|
||||
#elif 1
|
||||
vector_type<float, 4>::MemoryType dst;
|
||||
|
||||
uint32_t src_thread_addr_offset = src_thread_data_offset * sizeof(float);
|
||||
uint32_t src_const_addr_offset = src_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t src_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&src_block_setting) = const_cast<float*>(p_src_block);
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&src_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&src_block_setting)[3] = 0x00027000;
|
||||
|
||||
dst = __llvm_amdgcn_buffer_loadx4(
|
||||
src_block_setting, 0, src_thread_addr_offset + src_const_addr_offset, false, false);
|
||||
|
||||
return dst;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void __buffer_store<float, 1>(const float& src,
|
||||
float* p_dst_block,
|
||||
uint32_t dst_thread_data_offset,
|
||||
uint32_t dst_const_data_offset)
|
||||
{
|
||||
#if 0
|
||||
uint32_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
uint32_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t dst_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&dst_block_setting) = p_dst_block;
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&dst_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&dst_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_store_dword %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_setting),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#else
|
||||
uint32_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
uint32_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t dst_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&dst_block_setting) = p_dst_block;
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&dst_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&dst_block_setting)[3] = 0x00027000;
|
||||
|
||||
__llvm_amdgcn_buffer_store(
|
||||
src, dst_block_setting, 0, dst_thread_addr_offset + dst_const_addr_offset, false, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void __buffer_store<float, 2>(const vector_type<float, 2>::MemoryType& src,
|
||||
float* p_dst_block,
|
||||
uint32_t dst_thread_data_offset,
|
||||
uint32_t dst_const_data_offset)
|
||||
{
|
||||
#if 0
|
||||
uint32_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
uint32_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t dst_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&dst_block_setting) = p_dst_block;
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&dst_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&dst_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_store_dwordx2 %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_setting),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#else
|
||||
uint32_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
uint32_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t dst_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&dst_block_setting) = p_dst_block;
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&dst_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&dst_block_setting)[3] = 0x00027000;
|
||||
|
||||
__llvm_amdgcn_buffer_storex2(
|
||||
src, dst_block_setting, 0, dst_thread_addr_offset + dst_const_addr_offset, false, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ void __buffer_store<float, 4>(const vector_type<float, 4>::MemoryType& src,
|
||||
float* p_dst_block,
|
||||
uint32_t dst_thread_data_offset,
|
||||
uint32_t dst_const_data_offset)
|
||||
{
|
||||
#if 0
|
||||
uint32_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
uint32_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t dst_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&dst_block_setting) = p_dst_block;
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&dst_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&dst_block_setting)[3] = 0x00027000;
|
||||
|
||||
asm volatile("\n \
|
||||
buffer_store_dwordx4 %1, %2, %0, %3 offen offset:0 \n \
|
||||
"
|
||||
:
|
||||
: "s"(dst_block_setting),
|
||||
"v"(src),
|
||||
"v"(dst_thread_addr_offset),
|
||||
"s"(dst_const_addr_offset));
|
||||
#else
|
||||
uint32_t dst_thread_addr_offset = dst_thread_data_offset * sizeof(float);
|
||||
uint32_t dst_const_addr_offset = dst_const_data_offset * sizeof(float);
|
||||
|
||||
int32x4_t dst_block_setting{0};
|
||||
// fill in byte 0 - 1
|
||||
*reinterpret_cast<float**>(&dst_block_setting) = p_dst_block;
|
||||
// fill in byte 2
|
||||
reinterpret_cast<int*>(&dst_block_setting)[2] = -1;
|
||||
// fill in byte 3
|
||||
reinterpret_cast<int*>(&dst_block_setting)[3] = 0x00027000;
|
||||
|
||||
__llvm_amdgcn_buffer_storex4(
|
||||
src, dst_block_setting, 0, dst_thread_addr_offset + dst_const_addr_offset, false, false);
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -5,14 +5,12 @@
#include "utility.hpp"
#include "integral_constant.hpp"
#include "number.hpp"
#include "float_type.hpp"
#include "type.hpp"
#include "tuple.hpp"
#include "math.hpp"
#include "vector_type.hpp"
#include "sequence.hpp"
#include "sequence_helper.hpp"
#include "array.hpp"
#include "array_helper.hpp"
#include "functional.hpp"
#include "functional2.hpp"
#include "functional3.hpp"
@@ -22,8 +20,8 @@
#include "amd_inline_asm.hpp"
#endif

#if CK_USE_AMD_INTRINSIC
#include "amd_intrinsic.hpp"
#if CK_USE_AMD_BUFFER_ADDRESSING
#include "amd_buffer_addressing.hpp"
#endif

#endif
composable_kernel/include/utility/config.amd.hpp.in (new file, 70 lines)
@@ -0,0 +1,70 @@
#ifndef CK_CONFIG_AMD_HPP
|
||||
#define CK_CONFIG_AMD_HPP
|
||||
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
#include "bfloat16_dev.hpp"
|
||||
|
||||
// index type: unsigned or signed
|
||||
#define CK_UNSIGNED_INDEX_TYPE 0
|
||||
|
||||
// device backend
|
||||
#define CK_DEVICE_BACKEND_AMD 1
|
||||
|
||||
// AMD inline asm
|
||||
#ifndef CK_USE_AMD_INLINE_ASM
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
|
||||
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
// AMD buffer addressing
|
||||
#ifndef CK_USE_AMD_BUFFER_ADDRESSING
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 1
|
||||
#endif
|
||||
|
||||
// AMD XDLOPS
|
||||
#ifndef CK_USE_AMD_XDLOPS
|
||||
#define CK_USE_AMD_XDLOPS 1
|
||||
#endif
|
||||
|
||||
#ifndef CK_USE_AMD_XDLOPS_INLINE_ASM
|
||||
#define CK_USE_AMD_XDLOPS_INLINE_ASM 1
|
||||
#endif
|
||||
|
||||
// experimental implementation
|
||||
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 1
|
||||
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
|
||||
#define CK_EXPERIMENTAL_THREADWISE_COPY_V4R2_USE_OPTIMIZED_ADDRESS_CACLULATION 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
|
||||
// workaround
|
||||
#define CK_WORKAROUND_SWDEV_202749 1
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum AddressSpace
|
||||
{
|
||||
generic,
|
||||
global
|
||||
};
|
||||
|
||||
#if CK_UNSIGNED_INDEX_TYPE
|
||||
using index_t = uint32_t;
|
||||
#else
|
||||
using index_t = int32_t;
|
||||
#endif
|
||||
|
||||
// int32x4_t use by buffer_load and buffer_store llvm intrinsic
|
||||
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
} // namespace ck
|
||||
#endif
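Because every feature macro above is wrapped in `#ifndef`, a build can switch a path off (or on) per translation unit before the header is pulled in, either from the compiler command line or with a `#define` ahead of the include. A hedged example; the include name used here is an assumed installed name for the header generated from this `.in` template, so adjust it to the actual path:

// Force the inline-asm fallback for buffer access in this translation unit,
// and disable XDLOPS, before the config header applies its defaults.
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0
#define CK_USE_AMD_XDLOPS 0

#include "config.hpp" // assumed name of the generated config header

// equivalently, from the command line:
//   hipcc -DCK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC=0 -DCK_USE_AMD_XDLOPS=0 ...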
composable_kernel/include/utility/config.nvidia.hpp.in (new file, 46 lines)
@@ -0,0 +1,46 @@
#ifndef CK_CONFIG_NVIDIA_HPP
|
||||
#define CK_CONFIG_NVIDIA_HPP
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "cuda_fp16.h"
|
||||
#include "nvToolsExt.h"
|
||||
#include "helper_cuda.h"
|
||||
|
||||
// index type: unsigned or signed
|
||||
#define CK_UNSIGNED_INDEX_TYPE 0
|
||||
|
||||
// device backend
|
||||
#define CK_DEVICE_BACKEND_NVIDIA 1
|
||||
|
||||
// disable AMD inline asm and intrinsic
|
||||
#define CK_USE_AMD_INLINE_ASM 0
|
||||
#define CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM 0
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING 0
|
||||
#define CK_USE_AMD_BUFFER_ADDRESSING_INTRINSIC 0
|
||||
#define CK_USE_AMD_XDLOPS 0
|
||||
#define CK_USE_AMD_XDLOPS_INLINE_ASM 0
|
||||
|
||||
// experimental implementation
|
||||
#define CK_EXPERIMENTAL_BLOCKWISE_GEMM_USE_PIPELINE 0
|
||||
#define CK_EXPERIMENTAL_TENSOR_COORDINATE_USE_CALCULATE_OFFSET_DIFF 0
|
||||
#define CK_EXPERIMENTAL_THREADWISE_COPY_V4R2_USE_OPTIMIZED_ADDRESS_CACLULATION 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum AddressSpace
|
||||
{
|
||||
generic,
|
||||
global = generic
|
||||
};
|
||||
|
||||
#if CK_UNSIGNED_INDEX_TYPE
|
||||
using index_t = uint32_t;
|
||||
#else
|
||||
using index_t = int32_t;
|
||||
#endif
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,51 +0,0 @@
|
||||
#ifndef CK_CONFIG_AMD_HPP
|
||||
#define CK_CONFIG_AMD_HPP
|
||||
|
||||
#include "hip/hip_runtime.h"
|
||||
#include "hip/hip_fp16.h"
|
||||
|
||||
#define CK_UNSIGNED_INDEX_TYPE 0
|
||||
#define CK_DEVICE_BACKEND_AMD 1
|
||||
#define CK_USE_AMD_INTRINSIC 1
|
||||
#define CK_USE_AMD_INLINE_ASM 1
|
||||
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 1
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum address_space_t
|
||||
{
|
||||
generic = 0,
|
||||
global = 3
|
||||
};
|
||||
|
||||
#if CK_UNSIGNED_INDEX_TYPE
|
||||
using index_t = uint32_t;
|
||||
#else
|
||||
using index_t = int32_t;
|
||||
#endif
|
||||
|
||||
// For some reason, HIP compiler need this definition to generate optimal load and store
|
||||
// instruction
|
||||
typedef float float2_t __attribute__((ext_vector_type(2)));
|
||||
typedef float float4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
typedef int32_t int32x4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
// data type conversion
|
||||
template <typename T>
|
||||
struct type_convert
|
||||
{
|
||||
template <typename X>
|
||||
__device__ T operator()(const X& x) const
|
||||
{
|
||||
return static_cast<T>(x);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
@@ -1,53 +0,0 @@
|
||||
#ifndef CK_CONFIG_NVIDIA_HPP
|
||||
#define CK_CONFIG_NVIDIA_HPP
|
||||
|
||||
#include "cuda_runtime.h"
|
||||
#include "cuda_fp16.h"
|
||||
#include "nvToolsExt.h"
|
||||
#include "helper_cuda.h"
|
||||
|
||||
#define CK_UNSIGNED_INDEX_TYPE 0
|
||||
#define CK_DEVICE_BACKEND_NVIDIA 1
|
||||
#define CK_USE_AMD_INTRINSIC 0
|
||||
#define CK_USE_AMD_INLINE_ASM 0
|
||||
#define CK_USE_AMD_INTRINSIC_BUFFER_LOAD_STORE 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_BLOCKWISE_GENERIC_SLICE_COPY_V1 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R1 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V1R2 0
|
||||
#define CK_EXPERIMENTAL_USE_MORE_COMPILE_STATIC_THREADWISE_GENERIC_TENSOR_SLICE_COPY_V2R1 0
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum address_space_t
|
||||
{
|
||||
generic = 0,
|
||||
global = generic
|
||||
};
|
||||
|
||||
#if CK_UNSIGNED_INDEX_TYPE
|
||||
using index_t = uint32_t;
|
||||
#else
|
||||
using index_t = int32_t;
|
||||
#endif
|
||||
|
||||
// For some reason, CUDA need this definition, otherwise
|
||||
// compiler won't generate optimal load and store instruction, and
|
||||
// kernel would produce wrong result, indicating the compiler fail to generate correct
|
||||
// instruction,
|
||||
using float2_t = float2;
|
||||
using float4_t = float4;
|
||||
|
||||
// data type conversion
|
||||
template <typename T>
|
||||
struct type_convert
|
||||
{
|
||||
template <typename X>
|
||||
__device__ T operator()(const X& x) const
|
||||
{
|
||||
return static_cast<T>(x);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
|
||||
#endif
|
||||
composable_kernel/include/utility/float_type.amd.hpp.in (new file, 310 lines)
@@ -0,0 +1,310 @@
#ifndef CK_FLOAT_TYPE_AMD_HPP
|
||||
#define CK_FLOAT_TYPE_AMD_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
// For some reason, HIP compiler need this definition to generate optimal ISA
|
||||
// float
|
||||
typedef float float2_t __attribute__((ext_vector_type(2)));
|
||||
typedef float float4_t __attribute__((ext_vector_type(4)));
|
||||
typedef float float32_t __attribute__((ext_vector_type(32)));
|
||||
|
||||
// float16
|
||||
typedef _Float16 half2_t __attribute__((ext_vector_type(2)));
|
||||
typedef _Float16 half4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
// bfloat16
|
||||
typedef ushort ushort2_t __attribute__((ext_vector_type(2)));
|
||||
typedef ushort ushort4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
template <class T, index_t N>
|
||||
struct vector_type
|
||||
{
|
||||
typedef struct
|
||||
{
|
||||
T scalar[N];
|
||||
} MemoryType;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 1>
|
||||
{
|
||||
using MemoryType = float;
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
|
||||
{
|
||||
static_assert(I < 1, "wrong");
|
||||
*(reinterpret_cast<float*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<float, 2>
|
||||
{
|
||||
using MemoryType = float2_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
float scalar[2];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
|
||||
{
|
||||
static_assert(I < 2, "wrong");
|
||||
*(reinterpret_cast<float*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(float s0, float s1)
|
||||
{
|
||||
DataType data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
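`vector_type<float, 2>` pairs the raw `float2_t` memory type with helpers for building and mutating it: `Pack` assembles a vector from scalars through the union, and `SetScalar` writes one lane selected by a compile-time `Number<I>`. A short device-side usage sketch, assuming the `Number` alias from the utility headers:

// Sketch: build a float2_t from two scalars, then overwrite lane 1 in place.
__device__ ck::float2_t pack_and_update(float x, float y, float new_y)
{
    using vec2 = ck::vector_type<float, 2>;

    vec2::MemoryType v = vec2::Pack(x, y);      // v = {x, y}
    vec2::SetScalar(v, new_y, ck::Number<1>{}); // v = {x, new_y}
    return v;
}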
template <>
|
||||
struct vector_type<float, 4>
|
||||
{
|
||||
using MemoryType = float4_t;
|
||||
|
||||
__host__ __device__ static constexpr index_t GetSize() { return 4; }
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
|
||||
{
|
||||
static_assert(I < 4, "wrong");
|
||||
*(reinterpret_cast<float*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 1>
|
||||
{
|
||||
using MemoryType = half;
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
|
||||
{
|
||||
static_assert(I < 1, "wrong");
|
||||
*(reinterpret_cast<half*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 2>
|
||||
{
|
||||
using MemoryType = half2_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
half scalar[2];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
|
||||
{
|
||||
static_assert(I < 2, "wrong");
|
||||
*(reinterpret_cast<half*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(half s0, half s1)
|
||||
{
|
||||
DataType data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<half, 4>
|
||||
{
|
||||
using MemoryType = half4_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
half scalar[4];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
|
||||
{
|
||||
static_assert(I < 4, "wrong");
|
||||
*(reinterpret_cast<half*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(half s0, half s1, half s2, half s3)
|
||||
{
|
||||
DataType data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
data.scalar[2] = s2;
|
||||
data.scalar[3] = s3;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<ushort, 1>
|
||||
{
|
||||
using MemoryType = ushort;
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
|
||||
{
|
||||
static_assert(I < 1, "wrong");
|
||||
*(reinterpret_cast<ushort*>(&v) + I) = s;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<ushort, 2>
|
||||
{
|
||||
using MemoryType = ushort2_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
ushort scalar[2];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
|
||||
{
|
||||
static_assert(I < 2, "wrong");
|
||||
*(reinterpret_cast<ushort*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(ushort s0, ushort s1)
|
||||
{
|
||||
DataType data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct vector_type<ushort, 4>
|
||||
{
|
||||
using MemoryType = ushort4_t;
|
||||
|
||||
union DataType
|
||||
{
|
||||
MemoryType vector;
|
||||
ushort scalar[4];
|
||||
};
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ static void SetScalar(MemoryType& v, ushort s, Number<I>)
|
||||
{
|
||||
static_assert(I < 4, "wrong");
|
||||
*(reinterpret_cast<ushort*>(&v) + I) = s;
|
||||
}
|
||||
|
||||
__host__ __device__ static MemoryType Pack(ushort s0, ushort s1, ushort s2, ushort s3)
|
||||
{
|
||||
DataType data;
|
||||
data.scalar[0] = s0;
|
||||
data.scalar[1] = s1;
|
||||
data.scalar[2] = s2;
|
||||
data.scalar[3] = s3;
|
||||
return data.vector;
|
||||
}
|
||||
};
|
||||
|
||||
// data type conversion
|
||||
template <typename T>
|
||||
struct type_convert
|
||||
{
|
||||
template <typename X>
|
||||
__device__ T operator()(X x) const
|
||||
{
|
||||
return static_cast<T>(x);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
template <>
|
||||
__device__ float type_convert<float>::operator()<ushort>(ushort x) const
|
||||
{
|
||||
return bfloat16_to_float(x);
|
||||
}
|
||||
|
||||
template <>
|
||||
template <>
|
||||
__device__ ushort type_convert<ushort>::operator()<float>(float x) const
|
||||
{
|
||||
return float_to_bfloat16(x);
|
||||
}
|
||||
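// Editor's sketch (not part of this commit): bfloat16_to_float() and
// float_to_bfloat16() are declared elsewhere in these headers and are assumed
// here to be plain bit-level conversions, since bfloat16 is the upper 16 bits
// of an IEEE-754 binary32 value. A minimal truncating version, using
// hypothetical "_sketch" names so it cannot collide with the real helpers,
// could look like this (assumes unsigned int is 32-bit on the target):
__host__ __device__ inline float bfloat16_to_float_sketch(ushort x)
{
    union
    {
        float f;
        unsigned int u;
    } bits;
    bits.u = static_cast<unsigned int>(x) << 16; // bf16 payload occupies the high half
    return bits.f;
}

__host__ __device__ inline ushort float_to_bfloat16_sketch(float x)
{
    union
    {
        float f;
        unsigned int u;
    } bits;
    bits.f = x;
    return static_cast<ushort>(bits.u >> 16); // truncate toward zero, no rounding
}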
template <typename T>
struct inner_product_with_conversion
{
    static constexpr auto convert = type_convert<T>();

    __device__ T operator()(float a, float b) const { return convert(a) * convert(b); }

    __device__ T operator()(half2_t a, half2_t b) const
    {
        const half* p_a_half = reinterpret_cast<const half*>(&a);
        const half* p_b_half = reinterpret_cast<const half*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 2; ++v)
        {
            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
        }

        return acc;
    }

    __device__ T operator()(half4_t a, half4_t b) const
    {
        const half* p_a_half = reinterpret_cast<const half*>(&a);
        const half* p_b_half = reinterpret_cast<const half*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 4; ++v)
        {
            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
        }
        return acc;
    }

    __device__ T operator()(ushort2_t a, ushort2_t b) const
    {
        const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
        const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 2; ++v)
        {
            acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
        }

        return acc;
    }

    __device__ T operator()(ushort4_t a, ushort4_t b) const
    {
        const ushort* p_a_bfloat16 = reinterpret_cast<const ushort*>(&a);
        const ushort* p_b_bfloat16 = reinterpret_cast<const ushort*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 4; ++v)
        {
            acc += convert(p_a_bfloat16[v]) * convert(p_b_bfloat16[v]);
        }
        return acc;
    }
};

} // namespace ck
#endif
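Editor's note (illustration only, not part of this commit): a minimal sketch of how the vector_type and inner_product_with_conversion helpers above might be used from device code; pack_and_dot_example is a hypothetical name, and half is assumed to alias _Float16 as it does elsewhere in these headers.

namespace ck {

__device__ float pack_and_dot_example()
{
    using vec_f2 = vector_type<float, 2>;
    using vec_h2 = vector_type<half, 2>;

    // Pack two scalars into a float2_t, then overwrite lane 1 in place.
    vec_f2::MemoryType v = vec_f2::Pack(1.0f, 2.0f);
    vec_f2::SetScalar(v, 3.0f, Number<1>{}); // v now holds {1.0f, 3.0f}

    // Accumulate a half2 dot product in float via inner_product_with_conversion.
    half2_t a = vec_h2::Pack(static_cast<half>(1.0f), static_cast<half>(2.0f));
    half2_t b = vec_h2::Pack(static_cast<half>(3.0f), static_cast<half>(4.0f));

    return inner_product_with_conversion<float>{}(a, b); // 1*3 + 2*4 = 11
}

} // namespace ck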
157 composable_kernel/include/utility/float_type.nvidia.hpp.in Normal file
@@ -0,0 +1,157 @@
#ifndef CK_FLOAT_TYPE_NVIDIA_HPP
#define CK_FLOAT_TYPE_NVIDIA_HPP

#include "number.hpp"

namespace ck {

// For some reason, CUDA needs these definitions; otherwise the compiler
// won't generate optimal load and store instructions, and the kernel
// produces wrong results, indicating that the compiler fails to generate
// the correct instructions.
// float
using float2_t = float2;
using float4_t = float4;

// float16
using half2_t = half2;

template <class T, index_t N>
struct vector_type
{
    typedef struct
    {
        T scalar[N];
    } MemoryType;
};

template <>
struct vector_type<float, 1>
{
    using MemoryType = float;

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
    {
        static_assert(I < 1, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }
};

template <>
struct vector_type<float, 2>
{
    using MemoryType = float2_t;

    union DataType
    {
        MemoryType vector;
        float scalar[2];
    };

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
    {
        static_assert(I < 2, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }

    __host__ __device__ static MemoryType Pack(float s0, float s1)
    {
        DataType data;
        data.scalar[0] = s0;
        data.scalar[1] = s1;
        return data.vector;
    }
};

template <>
struct vector_type<float, 4>
{
    using MemoryType = float4_t;

    __host__ __device__ static constexpr index_t GetSize() { return 4; }

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
    {
        static_assert(I < 4, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }
};

template <>
struct vector_type<half, 1>
{
    using MemoryType = half;

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
    {
        static_assert(I < 1, "wrong");
        *(reinterpret_cast<half*>(&v) + I) = s;
    }
};

template <>
struct vector_type<half, 2>
{
    using MemoryType = half2_t;

    union DataType
    {
        MemoryType vector;
        half scalar[2];
    };

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, half s, Number<I>)
    {
        static_assert(I < 2, "wrong");
        *(reinterpret_cast<half*>(&v) + I) = s;
    }

    __host__ __device__ static MemoryType Pack(half s0, half s1)
    {
        DataType data;
        data.scalar[0] = s0;
        data.scalar[1] = s1;
        return data.vector;
    }
};

// data type conversion
template <typename T>
struct type_convert
{
    template <typename X>
    __device__ T operator()(const X& x) const
    {
        return static_cast<T>(x);
    }
};

template <typename T>
struct inner_product_with_conversion
{
    static constexpr auto convert = type_convert<T>();

    __device__ T operator()(float a, float b) const { return convert(a) * convert(b); }

    __device__ T operator()(half2_t a, half2_t b) const
    {
        const half* p_a_half = reinterpret_cast<const half*>(&a);
        const half* p_b_half = reinterpret_cast<const half*>(&b);

        T acc = 0;
        for(index_t v = 0; v < 2; ++v)
        {
            acc += convert(p_a_half[v]) * convert(p_b_half[v]);
        }

        return acc;
    }
};

} // namespace ck
#endif
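Editor's note (illustration only, not part of this commit): the NVIDIA variant exposes the same interface on top of CUDA's built-in float2/float4/half2 types, but only up to the specializations shown above. A minimal usage sketch, assuming cuda_fp16.h is available so that half and __float2half() exist; dot_half2_example is a hypothetical name.

namespace ck {

__device__ float dot_half2_example()
{
    // Pack two half values into a CUDA half2 and accumulate the dot product in float.
    half2_t a = vector_type<half, 2>::Pack(__float2half(1.0f), __float2half(2.0f));
    half2_t b = vector_type<half, 2>::Pack(__float2half(3.0f), __float2half(4.0f));

    return inner_product_with_conversion<float>{}(a, b); // 1*3 + 2*4 = 11
}

} // namespace ck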
@@ -1,5 +1,5 @@
#ifndef CK_ARRAY_HELPER_HPP
#define CK_ARRAY_HELPER_HPP
#ifndef CK_PRINT_ARRAY_HPP
#define CK_PRINT_ARRAY_HPP

#include "array.hpp"

@@ -1,5 +1,5 @@
#ifndef CK_SEQUENCE_HELPER_HPP
#define CK_SEQUENCE_HELPER_HPP
#ifndef CK_PRINT_SEQUENCE_HPP
#define CK_PRINT_SEQUENCE_HPP

#include "sequence.hpp"

@@ -1,87 +0,0 @@
#ifndef CK_VECTOR_TYPE_HPP
#define CK_VECTOR_TYPE_HPP

#include "config.hpp"
#include "integral_constant.hpp"

namespace ck {

template <class T, index_t N>
struct vector_type
{
};

template <>
struct vector_type<float, 1>
{
    using MemoryType = float;

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
    {
        static_assert(I < 1, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }
};

template <>
struct vector_type<float, 2>
{
    using MemoryType = float2_t;

    union Data
    {
        MemoryType vector;
        float scalar[2];
    };

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
    {
        static_assert(I < 2, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }

    __host__ __device__ static MemoryType Pack(float s0, float s1)
    {
        Data data;
        data.scalar[0] = s0;
        data.scalar[1] = s1;
        return data.vector;
    }
};

template <>
struct vector_type<float, 4>
{
    using MemoryType = float4_t;

    template <index_t I>
    __host__ __device__ static void SetScalar(MemoryType& v, float s, Number<I>)
    {
        static_assert(I < 4, "wrong");
        *(reinterpret_cast<float*>(&v) + I) = s;
    }
};

template <>
struct vector_type<const float, 1>
{
    using MemoryType = const float;
};

template <>
struct vector_type<const float, 2>
{
    using MemoryType = const float2_t;
};

template <>
struct vector_type<const float, 4>
{
    using MemoryType = const float4_t;
};

} // namespace ck

#endif