mirror of https://github.com/ROCm/composable_kernel.git (synced 2026-05-11 17:00:18 +00:00)
commit: added implicit gemm v1r3 lds_double_buffer NCHW * CYXK = KHWN, reworked static functionals
@@ -3,7 +3,6 @@
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r1_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_chwn_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp"

@@ -81,6 +80,8 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,

#if 0
// for 3x3, 34x34, v1r1, Pascal
constexpr index_t BlockSize = 128;

constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 4;

@@ -92,14 +93,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;

constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 4;
constexpr index_t InBlockCopyDataPerRead_N = 4;

constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;

@@ -110,11 +103,16 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
using InBlockCopyClusterLengths_CHWN = Sequence<4, 4, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;

constexpr index_t BlockSize = 128;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 3x3, 34x34, v1r2, Pascal, in-block-copy1
constexpr index_t BlockSize = 128;

constexpr index_t NPerBlock = 4;
constexpr index_t KPerBlock = 64;
constexpr index_t CPerBlock = 8;

@@ -126,14 +124,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;

constexpr index_t InBlockCopy_ThreadPerDimC = 4;
constexpr index_t InBlockCopy_ThreadPerDimH = 4;
constexpr index_t InBlockCopy_ThreadPerDimW = 2;
constexpr index_t InBlockCopy_ThreadPerDimN = 1;
constexpr index_t InBlockCopyDataPerRead_N = 4;

constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;

@@ -144,9 +134,76 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
using InBlockCopyClusterLengths_CHWN = Sequence<0, 0, 0, 0>; // not used
constexpr index_t InBlockCopyDataPerRead_N = 4;

constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
constexpr index_t BlockSize = 128;

constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;

constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;

constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 3x3, 34x34, v1r3, Pascal, bad
constexpr index_t BlockSize = 128;

constexpr index_t NPerBlock = 1;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 32;

constexpr index_t NPerThread = 1;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 8;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopyClusterLengths_CHWN = Sequence<2, 2, 32, 1>;
constexpr index_t InBlockCopyDataPerRead_N = 1;

constexpr index_t WeiBlockCopyDataPerRead_K = 2;

constexpr index_t OutThreadCopyDataPerWrite_N = 1;
#elif 0
// for 3x3, 34x34, v1r1, Vega 20
constexpr index_t NPerBlock = 16;

@@ -309,38 +366,6 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,

constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
constexpr index_t BlockSize = 128;

constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;

constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockCopyClusterLengths_CHWN = Sequence<8, 2, 2, 4>;
constexpr index_t InBlockCopyDataPerRead_N = 4;

constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 1x1, 28x28, v1r1, Pascal

@@ -419,13 +444,11 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr auto gridwise_conv =
#if 0
GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
#elif 0
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
#endif
<GridSize,

@@ -3,6 +3,8 @@
#include "device.hpp"
#include "gridwise_convolution_wrapper.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r2_nchw_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_nchw_cyxk_khwn.hip.hpp"
#include "gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp"

template <class T, class InDesc, class WeiDesc, class OutDesc>
void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,

@@ -62,7 +64,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
wei_cyxk_device_buf.ToDevice(wei_cyxk.mData.data());
out_khwn_device_buf.ToDevice(out_khwn.mData.data());

#if 1
#if 0
// for 3x3, 28x28, v1r2, Pascal
constexpr index_t BlockSize = 128;

@@ -93,8 +95,78 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
constexpr index_t InBlockReorderDataPerRead_W = 2;
constexpr index_t InBlockReorderDataPerWrite_N = 4;

using WeiBlockCopyClusterLengths_CXK = Sequence<4, 1, 32>;
constexpr index_t WeiBlockCopyDataPerRead_C = 4;
using WeiBlockCopyClusterLengths = Sequence<4, 1, 32>;
constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 0
// for 3x3, 28x28, v1r3, Pascal, bad
constexpr index_t BlockSize = 128;

constexpr index_t NPerBlock = 16;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 2;

constexpr index_t NPerThread = 4;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 2;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockReorderSrcSubLengths_NCHW = Sequence<4, 1, 1, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<4, 8, 2, 2>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet

using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#elif 1
// for 3x3, 34x34, v1r3, Pascal
constexpr index_t BlockSize = 128;

constexpr index_t NPerBlock = 2;
constexpr index_t KPerBlock = 128;
constexpr index_t CPerBlock = 8;
constexpr index_t HoPerBlock = 2;
constexpr index_t WoPerBlock = 16;

constexpr index_t NPerThread = 2;
constexpr index_t KPerThread = 8;
constexpr index_t HoPerThread = 1;
constexpr index_t WoPerThread = 4;

constexpr index_t GemmMPerThreadSubC = 4;
constexpr index_t GemmNPerThreadSubC = 4;
constexpr index_t GemmMLevel0Cluster = 4;
constexpr index_t GemmNLevel0Cluster = 2;
constexpr index_t GemmMLevel1Cluster = 4;
constexpr index_t GemmNLevel1Cluster = 2;
constexpr index_t GemmKPerThreadLoop = 1;
constexpr index_t GemmDataPerReadA = 4;
constexpr index_t GemmDataPerReadB = 4;

using InBlockReorderSrcSubLengths_NCHW = Sequence<2, 1, 2, 1>;
using InBlockReorderSrcClusterLengths_NCHW = Sequence<1, 8, 1, 16>;
using InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW = Sequence<1, 2, 0, 3>;
constexpr index_t InBlockReorderDataPerRead_W = 1; // v1r3 cannot do vector load input for NCHW
constexpr index_t InBlockReorderDataPerWrite_N = 1; // not used yet

using WeiBlockCopyClusterLengths = Sequence<0, 0>; // not used
constexpr index_t WeiBlockCopyDataPerRead_K = 4;

constexpr index_t OutThreadCopyDataPerWrite_N = 2;
#endif

@@ -108,8 +180,12 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
for(index_t i = 0; i < nrepeat; ++i)
{
constexpr auto gridwise_conv =
#if 1
#if 0
GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
#elif 1
GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
#endif
<GridSize,
BlockSize,

@@ -140,8 +216,8 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
InBlockReorderDataPerRead_W,
InBlockReorderDataPerWrite_N,
WeiBlockCopyClusterLengths_CXK,
WeiBlockCopyDataPerRead_C,
WeiBlockCopyClusterLengths,
WeiBlockCopyDataPerRead_K,
OutThreadCopyDataPerWrite_N>{};
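
A hedged sanity check on these tuning parameters (not part of the commit; a standalone sketch assuming the usual decomposition in this kernel family, where GEMM M maps to K, GEMM N maps to Wo*N, and Ho acts as the GEMM batch dimension): the thread clusters have to multiply out to BlockSize, and the per-thread tile has to tile the per-block tile exactly. For the active "3x3, 34x34, v1r3" NCHW config above:

// hypothetical consistency check, with the concrete values copied from the config
constexpr int BlockSize = 128, HoPerBlock = 2, HoPerThread = 1;
constexpr int KPerBlock = 128, KPerThread = 8;
constexpr int NPerBlock = 2, WoPerBlock = 16, NPerThread = 2, WoPerThread = 4;
constexpr int GemmMLevel0Cluster = 4, GemmNLevel0Cluster = 2;
constexpr int GemmMLevel1Cluster = 4, GemmNLevel1Cluster = 2;

constexpr int GemmThreads =
    GemmMLevel0Cluster * GemmNLevel0Cluster * GemmMLevel1Cluster * GemmNLevel1Cluster; // 64
static_assert(GemmThreads * (HoPerBlock / HoPerThread) == BlockSize, // 64 * 2 == 128
              "thread clusters must fill the workgroup");
static_assert(KPerBlock / (GemmMLevel0Cluster * GemmMLevel1Cluster) == KPerThread, // 128/16 == 8
              "M tile per thread");
static_assert(NPerBlock * WoPerBlock / (GemmNLevel0Cluster * GemmNLevel1Cluster) ==
                  NPerThread * WoPerThread, // (2*16)/4 == 2*4
              "N tile per thread");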

float time = launch_kernel(run_gridwise_convolution<decltype(gridwise_conv), T>,

@@ -46,7 +46,7 @@ struct GeneratorTensor_3
#if 0
auto f_acc = std::plus<index_t>{};
#else
auto f_acc = [](auto a, auto b) { return 10 * a + b; };
auto f_acc = [](auto a, auto b) { return 100 * a + b; };
#endif

return std::accumulate(dims.begin(), dims.end(), index_t(0), f_acc);
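
The fold above encodes a multi-index as digits: with the new 100-base accumulator, an index like (1, 2, 3) generates the value 10203, so each tensor element carries its own coordinates, which is handy when eyeballing copies and reorders. A minimal standalone illustration (hypothetical driver code, not from the repo):

#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    std::vector<int> dims = {1, 2, 3};
    auto f_acc = [](auto a, auto b) { return 100 * a + b; };
    // ((0 * 100 + 1) * 100 + 2) * 100 + 3 = 10203
    std::printf("%d\n", std::accumulate(dims.begin(), dims.end(), 0, f_acc));
}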

@@ -390,8 +390,6 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
template <class T>
void check_error(const Tensor<T>& ref, const Tensor<T>& result)
{
// printf("\n");

float error = 0;
float max_diff = -1;
float ref_value = 0, result_value = 0;

@@ -405,10 +403,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
ref_value = ref.mData[i];
result_value = result.mData[i];
}

// printf("{%f, %f}", double(ref.mData[i]), double(result.mData[i]));
}
// printf("\n");

std::cout << "error: " << error << std::endl;
std::cout << "max_diff: " << max_diff << ", " << ref_value << ", " << result_value << std::endl;

@@ -416,38 +411,27 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)

int main(int argc, char* argv[])
{
#if 0
constexpr index_t N = 128;
constexpr index_t C = 8;
constexpr index_t HI = 28;
constexpr index_t WI = 28;
#if 1
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 34x34
constexpr index_t N = 64;
constexpr index_t C = 256;
constexpr index_t HI = 34;
constexpr index_t WI = 34;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 0
// 3x3, 56x56
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t N = 64;
constexpr index_t C = 64;
constexpr index_t HI = 56;
constexpr index_t WI = 56;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;
constexpr index_t K = 128;
constexpr index_t Y = 3;
constexpr index_t X = 3;

constexpr index_t HPad = 0;
constexpr index_t WPad = 0;

@@ -499,7 +483,7 @@ int main(int argc, char* argv[])

constexpr index_t HPad = 1;
constexpr index_t WPad = 1;
#elif 1
#elif 0
// 5x5 filter, 20x86 image
constexpr index_t N = 16;
constexpr index_t C = 256;

@@ -547,7 +531,7 @@ int main(int argc, char* argv[])

constexpr index_t HPad = 0;
constexpr index_t WPad = 0;
#elif 10
#elif 0
// 1x1 filter, 14x14 image
constexpr index_t N = 128;
constexpr index_t C = 512;

@@ -619,9 +603,9 @@ int main(int argc, char* argv[])
device_direct_convolution_2_nchw_kcyx_nkhw
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 1
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn

@@ -19,6 +19,20 @@ struct Array
__host__ __device__ const TData& operator[](index_t i) const { return mData[i]; }

__host__ __device__ TData& operator[](index_t i) { return mData[i]; }

__host__ __device__ auto PushBack(TData x) const
{
Array<TData, NSize + 1> new_array;

static_for<0, NSize, 1>{}([&](auto I) {
constexpr index_t i = I.Get();
new_array[i] = mData[i];
});

new_array[NSize] = x;

return new_array;
}
};
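
A quick standalone sketch of what the new PushBack buys (a minimal analogue with plain types, not the repo's Array/static_for; the lambda must capture new_array by reference so the writes land in the result): it returns a longer array by value, so calls chain naturally.

#include <cstdio>

template <class T, int N>
struct SmallArray
{
    T data[N > 0 ? N : 1]; // avoid a zero-length array for the empty case

    SmallArray<T, N + 1> PushBack(T x) const
    {
        SmallArray<T, N + 1> r{};
        for(int i = 0; i < N; ++i)
            r.data[i] = data[i];
        r.data[N] = x;
        return r; // a new, longer array; the original is untouched
    }
};

int main()
{
    SmallArray<int, 0> empty{};
    auto a = empty.PushBack(3).PushBack(7); // SmallArray<int, 2> holding {3, 7}
    std::printf("%d %d\n", a.data[0], a.data[1]);
}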

template <class TData, index_t NSize, index_t... IRs>

@@ -51,4 +65,4 @@ __host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>&
});

return new_array;
}
}

@@ -1,80 +1,30 @@
#pragma once
#include "common.hip.hpp"

// this is ugly, only for 2d
template <index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
template <class PreviousStrides, class RemainLengths>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
{
return Sequence<L1, 1>{};
constexpr index_t previous_stride = PreviousStrides{}.Front();
constexpr index_t current_length = RemainLengths{}.Back();
constexpr index_t current_stride = current_length * previous_stride;

return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
RemainLengths{}.PopBack());
}

// this is ugly, only for 3d
template <index_t L0, index_t L1, index_t L2>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
template <class PreviousStrides, index_t L0, index_t L1>
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
{
return Sequence<L1 * L2, L2, 1>{};
constexpr index_t previous_stride = PreviousStrides{}.Front();
constexpr index_t current_stride = L1 * previous_stride;

return PreviousStrides{}.PushFront(Number<current_stride>{});
}

// this is ugly, only for 4d
template <index_t L0, index_t L1, index_t L2, index_t L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
template <class Lengths>
__host__ __device__ constexpr auto calculate_default_strides(Lengths)
{
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}

// this is ugly, only for 6d
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
{
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
}

// this is ugly, only for 8d
template <index_t L0,
index_t L1,
index_t L2,
index_t L3,
index_t L4,
index_t L5,
index_t L6,
index_t L7>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
{
return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7,
L2 * L3 * L4 * L5 * L6 * L7,
L3 * L4 * L5 * L6 * L7,
L4 * L5 * L6 * L7,
L5 * L6 * L7,
L6 * L7,
L7,
1>{};
}

// this is ugly, only for 10d
template <index_t L0,
index_t L1,
index_t L2,
index_t L3,
index_t L4,
index_t L5,
index_t L6,
index_t L7,
index_t L8,
index_t L9>
__host__ __device__ constexpr auto
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7, L8, L9>)
{
return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
L3 * L4 * L5 * L6 * L7 * L8 * L9,
L4 * L5 * L6 * L7 * L8 * L9,
L5 * L6 * L7 * L8 * L9,
L6 * L7 * L8 * L9,
L7 * L8 * L9,
L8 * L9,
L9,
1>{};
return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
}
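
To make the new recursion concrete (a worked trace, not code from the commit): for Lengths = Sequence<2, 3, 4>, the impl starts from Sequence<1>, multiplies the back length into the front stride at each step, and stops at the two-length base case:

// calculate_default_strides(Sequence<2, 3, 4>)
// -> calculate_default_strides_impl(Sequence<1>,    Sequence<2, 3, 4>)   // push 4*1
// -> calculate_default_strides_impl(Sequence<4, 1>, Sequence<2, 3>)      // base case: push 3*4
// -> Sequence<12, 4, 1>                                                  // row-major strides

which matches what the old hand-written 3d overload returned, Sequence<L1 * L2, L2, 1> = Sequence<12, 4, 1>.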

// this is ugly, only for 2d

@@ -186,6 +136,14 @@ struct ConstantTensorDescriptor
return Get1dIndex(multi_id);
}

template <index_t... Is>
__host__ __device__ static constexpr index_t Get1dIndex(Sequence<Is...> multi_id)
{
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");

return Get1dIndex(Is...);
}

__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
{
Array<index_t, nDim> multi_id;

@@ -34,11 +34,15 @@ struct Sequence
template <index_t... IRs>
__host__ __device__ constexpr auto ReorderGivenOld2New(Sequence<IRs...> /*old2new*/) const
{
// don't know how to implement this
// TODO: don't know how to implement this
printf("Sequence::ReorderGivenOld2New not implemented");
assert(false);
}

__host__ __device__ constexpr index_t Front() const { return mData[0]; }

__host__ __device__ constexpr index_t Back() const { return mData[mSize - 1]; }

template <index_t I>
__host__ __device__ constexpr auto PushFront(Number<I>) const
{

@@ -69,15 +73,98 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
return Sequence<Is...>{};
}

template <index_t... Is, index_t I>
#if 0
// TODO: compiler cannot instantiate this template: the pack Is... sits in a
// non-deduced position before I, so Sequence<Is..., I> never matches
template <index_t I, index_t... Is>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
static_assert(sizeof...(Is) > 0, "empty Sequence!");
return Sequence<Is...>{};
}
#else
// TODO: delete this very ugly mess
template <index_t I0, index_t I1>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
{
return Sequence<I0>{};
}

template <index_t I0, index_t I1, index_t I2>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2>)
{
return Sequence<I0, I1>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3>)
{
return Sequence<I0, I1, I2>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4>)
{
return Sequence<I0, I1, I2, I3>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5>)
{
return Sequence<I0, I1, I2, I3, I4>{};
}

template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6>)
{
return Sequence<I0, I1, I2, I3, I4, I5>{};
}

template <index_t I0,
index_t I1,
index_t I2,
index_t I3,
index_t I4,
index_t I5,
index_t I6,
index_t I7>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7>)
{
return Sequence<I0, I1, I2, I3, I4, I5, I6>{};
}

template <index_t I0,
index_t I1,
index_t I2,
index_t I3,
index_t I4,
index_t I5,
index_t I6,
index_t I7,
index_t I8>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>)
{
return Sequence<I0, I1, I2, I3, I4, I5, I6, I7>{};
}

template <index_t I0,
index_t I1,
index_t I2,
index_t I3,
index_t I4,
index_t I5,
index_t I6,
index_t I7,
index_t I8,
index_t I9>
__host__ __device__ constexpr auto
sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8, I9>)
{
return Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>{};
}
#endif
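
One generic alternative to the per-arity overloads (a hedged standalone sketch using std::index_sequence, not code from this commit, with its own minimal Seq type): select elements 0..N-2 by indexing a constexpr copy of the pack.

#include <cstddef>
#include <type_traits>
#include <utility>

template <int... Is>
struct Seq
{
};

// keep elements 0 .. N-2 by indexing a constexpr copy of the pack
template <int... Is, std::size_t... Js>
constexpr auto seq_pop_back_impl(Seq<Is...>, std::index_sequence<Js...>)
{
    constexpr int data[] = {Is...};
    return Seq<data[Js]...>{};
}

template <int... Is>
constexpr auto seq_pop_back(Seq<Is...>)
{
    static_assert(sizeof...(Is) > 0, "empty Seq!");
    return seq_pop_back_impl(Seq<Is...>{}, std::make_index_sequence<sizeof...(Is) - 1>{});
}

static_assert(std::is_same<decltype(seq_pop_back(Seq<1, 2, 3>{})), Seq<1, 2>>::value, "");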

#if 1
// this is ugly, only for 2 sequences
// TODO: fix this mess
template <class F, index_t... Xs, index_t... Ys>
__host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>)
{

@@ -86,7 +173,6 @@ __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>, Seq
return Sequence<f(Xs, Ys)...>{};
}

// this is ugly, only for 3 sequences
template <class F, index_t... Xs, index_t... Ys, index_t... Zs>
__host__ __device__ constexpr auto
transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)

@@ -98,6 +184,7 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return Sequence<f(Xs, Ys, Zs)...>{};
}
#else
// TODO: these don't compile
template <index_t NRemain>
struct transform_sequences_impl
{

@@ -1,5 +1,6 @@
#pragma once
#include "ConstantTensorDescriptor.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"

template <index_t BlockSize, class Float, class DstDesc, class F>
__device__ void

@@ -957,6 +958,7 @@ struct Blockwise4dTensorCopyReorder3
constexpr auto thread_sub_tensor_desc =
make_ConstantTensorDescriptor(SrcClusterLengths{}, thread_tensor_desc.GetStrides());

#if 1
for(index_t icluster_d0 = 0; icluster_d0 < cluster_per_dims.Get(I0); ++icluster_d0)
{
for(index_t icluster_d1 = 0; icluster_d1 < cluster_per_dims.Get(I1); ++icluster_d1)

@@ -978,16 +980,21 @@ struct Blockwise4dTensorCopyReorder3
icluster_d2 * thread_sub_tensor_lengths.Get(I2),
icluster_d3 * thread_sub_tensor_lengths.Get(I3));

threadwise_4d_tensor_copy_v2(SrcDesc{},
p_src + src_offset + mSrcMyThreadOffset,
thread_tensor_desc,
p_clipboard + clipboard_offset,
thread_sub_tensor_lengths,
Number<SrcDataPerRead>{});
threadwise_nd_tensor_copy(SrcDesc{},
p_src + src_offset + mSrcMyThreadOffset,
thread_tensor_desc,
p_clipboard + clipboard_offset,
thread_sub_tensor_lengths,
Number<SrcDataPerRead>{});
}
}
}
}
#else
static_ford<decltype(cluster_per_dims)>{}([=](auto cluster_ids) {

});
#endif

#if 0
if(get_block_1d_id() == 0)

@@ -253,9 +253,23 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
p_a_thread[0], p_a_thread[1], p_a_thread[2], p_a_thread[3], p_a_thread[4], p_a_thread[5], p_a_thread[6], p_a_thread[7],
p_b_thread[0], p_b_thread[1], p_b_thread[2], p_b_thread[3], p_b_thread[4], p_b_thread[5], p_b_thread[6], p_b_thread[7]);
printf("a: %f %f %f %f %f %f %f %f, b: %f %f %f %f %f %f %f %f\n",
p_a_thread[0],
p_a_thread[1],
p_a_thread[2],
p_a_thread[3],
p_a_thread[4],
p_a_thread[5],
p_a_thread[6],
p_a_thread[7],
p_b_thread[0],
p_b_thread[1],
p_b_thread[2],
p_b_thread[3],
p_b_thread[4],
p_b_thread[5],
p_b_thread[6],
p_b_thread[7]);
}
#endif

@@ -4,6 +4,7 @@
#include "Sequence.hip.hpp"
#include "Array.hip.hpp"
#include "functional.hip.hpp"
#include "functional2.hip.hpp"

#if DEVICE_BACKEND_HIP
#include "amd_inline_asm.hip.hpp"

@@ -21,7 +21,7 @@ struct static_for_impl<Iter, 0, Increment>
template <class F>
__host__ __device__ void operator()(F) const
{
// do nothing
// no work left, just return
return;
}
};
@@ -48,7 +48,7 @@ struct static_const_reduce_n
static_assert(NLoop > 1, "out-of-range");

constexpr auto a = f(Number<NLoop - 1>{});
auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // cannot use constexpr here, weird
auto b = static_const_reduce_n<NLoop - 1>{}(f, r); // TODO: cannot use constexpr here, weird
return r(a, b);
}
};
@@ -70,3 +70,65 @@ __host__ __device__ constexpr auto unpacker(F f)
return [=](auto xs_array){ f(xs...); };
}
#endif

struct forwarder
{
template <typename T>
__host__ __device__ constexpr T operator()(T&& x) const
{
return std::forward<T>(x);
}
};

// Emulate compile time if statement for C++14
// Got the idea from
// "https://baptiste-wicht.com/posts/2015/07/simulate-static_if-with-c11c14.html"
// TODO: use if constexpr, when C++17 is supported
template <bool Predicate>
struct static_if
{
};

template <>
struct static_if<true>
{
using Type = static_if<true>;

template <class F>
__host__ __device__ constexpr auto operator()(F f) const
{
// This is a trick for the compiler:
// pass forwarder to lambda "f" as an "auto" argument, and make sure "f" uses it;
// this makes "f" a generic lambda, so "f" won't be compiled until here
f(forwarder{});
return Type{};
}

template <class F>
__host__ __device__ static constexpr auto else_(F)
{
return Type{};
}
};

template <>
struct static_if<false>
{
using Type = static_if<false>;

template <class F>
__host__ __device__ constexpr auto operator()(F) const
{
return Type{};
}

template <class F>
__host__ __device__ static constexpr auto else_(F f)
{
// This is a trick for the compiler:
// pass forwarder to lambda "f" as an "auto" argument, and make sure "f" uses it;
// this makes "f" a generic lambda, so "f" won't be compiled until here
f(forwarder{});
return Type{};
}
};
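
A minimal sketch of how this static_if is meant to be used (hypothetical example; the gridwise convolution below follows the same pattern with f_dummy as the forwarder argument). Each branch body takes the forwarder as an auto argument and routes values through it, so the not-taken branch stays uninstantiated:

// hypothetical usage: pick a compile-time branch without C++17 if constexpr
template <int N>
int describe()
{
    int result = 0;
    static_if<(N % 2 == 0)>{}([&](auto f_dummy) {
        result = f_dummy(N) / 2; // only instantiated when N is even
    }).else_([&](auto f_dummy) {
        result = 3 * f_dummy(N) + 1; // only instantiated when N is odd
    });
    return result;
}

The chaining works because operator() returns a static_if<true> or static_if<false> tag object, whose else_ either ignores or invokes the second lambda.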

src/include/functional2.hip.hpp (new file, 117 lines)
@@ -0,0 +1,117 @@
#pragma once
#include "Sequence.hip.hpp"

template <index_t RemainDim>
struct static_ford_impl
{
// F signature: F(Sequence<...> multi_id)
// CurrentMultiIndex: Sequence<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
{
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
static_assert(RemainDim > 1, "wrong!");

constexpr auto next_length = RemainLengths{}.Front();

static_for<0, next_length, 1>{}([=](auto I) {
static_ford_impl<RemainDim - 1>{}(
f, CurrentMultiIndex{}.PushBack(I), RemainLengths{}.PopFront());
});
}
};

template <>
struct static_ford_impl<1>
{
// F signature: F(Sequence<Is...> multi_id)
// CurrentMultiIndex: Sequence<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void operator()(F f, CurrentMultiIndex, RemainLengths) const
{
static_assert(RemainLengths::GetSize() == 1, "wrong!");

constexpr index_t last_length = RemainLengths{}.Front();

static_for<0, last_length, 1>{}([=](auto I) { f(CurrentMultiIndex{}.PushBack(I)); });
}
};

// Lengths is Sequence<...>
template <class Lengths>
struct static_ford
{
// F signature: F(Sequence<Is...> multi_id)
template <class F>
__host__ __device__ void operator()(F f) const
{
constexpr index_t first_length = Lengths{}.Front();

static_for<0, first_length, 1>{}([=](auto I) {
static_ford_impl<Lengths::GetSize() - 1>{}(
f, Sequence<I.Get()>{}, Lengths{}.PopFront());
});
}
};
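
A standalone sketch of what static_ford expands to (a hedged C++14 analogue with std::index_sequence instead of the repo's Sequence; for a 2d space it is equivalent to calling f on every pair, fully unrolled at compile time):

#include <cstdio>
#include <utility>

// expand one row of the index space: f(i, 0), f(i, 1), ...
template <class F, std::size_t... Js>
void visit_row(F f, std::size_t i, std::index_sequence<Js...>)
{
    int dummy[] = {(f(i, Js), 0)...}; // C++14 pack-expansion-in-initializer trick
    (void)dummy;
}

// expand all rows: every (i, j) pair is generated at compile time
template <class F, std::size_t... Is, class Inner>
void visit_all(F f, std::index_sequence<Is...>, Inner inner)
{
    int dummy[] = {(visit_row(f, Is, inner), 0)...};
    (void)dummy;
}

int main()
{
    visit_all([](std::size_t i, std::size_t j) { std::printf("(%zu,%zu) ", i, j); },
              std::make_index_sequence<2>{},
              std::make_index_sequence<3>{});
    // prints: (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
}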

template <index_t RemainDim>
struct ford_impl
{
// F signature: F(Array<...> multi_id)
// CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{
static_assert(RemainLengths::GetSize() == RemainDim, "wrong!");
static_assert(RemainDim > 1, "wrong!");

constexpr auto next_length = RemainLengths{}.Front();

for(index_t i = 0; i < next_length; ++i)
{
ford_impl<RemainDim - 1>{}(f, current_multi_id.PushBack(i), RemainLengths{}.PopFront());
}
}
};

template <>
struct ford_impl<1>
{
// F signature: F(Array<...> multi_id)
// CurrentMultiIndex: Array<...>
// RemainLengths: Sequence<...>
template <class F, class CurrentMultiIndex, class RemainLengths>
__host__ __device__ void
operator()(F f, CurrentMultiIndex current_multi_id, RemainLengths) const
{
static_assert(RemainLengths::GetSize() == 1, "wrong!");

constexpr index_t last_length = RemainLengths{}.Front();

for(index_t i = 0; i < last_length; ++i)
{
f(current_multi_id.PushBack(i));
}
}
};

// Lengths is Sequence<...>
template <class Lengths>
struct ford
{
// F signature: F(Array<...> multi_id)
template <class F>
__host__ __device__ void operator()(F f) const
{
constexpr index_t first_length = Lengths{}.Front();

for(index_t i = 0; i < first_length; ++i)
{
ford_impl<Lengths::GetSize() - 1>{}(f, Array<index_t, 1>{i}, Lengths{}.PopFront());
}
}
};
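
The runtime counterpart differs only in that the multi-index is an Array built in ordinary loops rather than a compile-time Sequence. A host-only standalone analogue (hypothetical, plain std types instead of the repo's Array):

#include <array>
#include <cstdio>

// analogue of ford<Sequence<2, 3>>: visit every multi-index in row-major order
template <class F>
void ford_2d(int len0, int len1, F f)
{
    for(int i0 = 0; i0 < len0; ++i0)
        for(int i1 = 0; i1 < len1; ++i1)
            f(std::array<int, 2>{i0, i1});
}

int main()
{
    ford_2d(2, 3, [](std::array<int, 2> id) { std::printf("(%d,%d) ", id[0], id[1]); });
    // prints: (0,0) (0,1) (0,2) (1,0) (1,1) (1,2)
}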

@@ -32,10 +32,10 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite>
class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,

@@ -43,28 +43,30 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0),
"wrong!");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto in_chwn_global_desc = InGlobalDesc{};
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
constexpr auto in_c_h_w_n_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

constexpr index_t C = in_chwn_global_desc.GetLength(I0);
constexpr index_t C = in_c_h_w_n_global_desc.GetLength(I0);

constexpr index_t K = out_khwn_global_desc.GetLength(I0);
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);

constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;

@@ -95,24 +97,35 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
const index_t wi_block_data_begin = wo_block_data_begin;

// flattened (2d) tensor view of gridwise weight
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});
constexpr auto wei_cyx_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr index_t max_align = mod_conv::max(
InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);

constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead_N>{});

constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<max_align>{});
// this check is ad hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not met");

constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<max_align>{});
constexpr auto wei_cyx_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

constexpr auto wei_c_y_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

// tensor view of threadwise output in register
constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor(
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

// blockwise copy

@@ -121,18 +134,18 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#if 0
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyDataPerRead_N>{};
#else
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{};
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#endif

// blockwise wei copy

@@ -140,10 +153,10 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
decltype(wei_cyx_k_global_desc),
decltype(wei_cyx_k_block_desc),
decltype(wei_cyx_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead_K>{};

// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix

@@ -151,28 +164,30 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});
constexpr auto a_c_k_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<KPerBlock>{},
Number<wei_c_y_x_k_block_desc.GetStride(I0)>{});

constexpr auto b_cxwn_block_mtx_desc =
constexpr auto b_c_wn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{});
Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

constexpr auto c_kxwn_thread_mtx_desc =
constexpr auto c_k_wn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I0)>{});
Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
decltype(a_c_k_block_mtx_desc),
decltype(b_c_wn_block_mtx_desc),
decltype(c_k_wn_thread_mtx_desc),
0,
in_chwn_block_desc.GetStride(I1),
out_khwn_thread_desc.GetStride(I1),
in_c_h_w_n_block_desc.GetStride(I1),
out_k_h_w_n_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,

@@ -186,30 +201,33 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
GemmDataPerReadB>{};

// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});

constexpr index_t wei_block_space =
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});
wei_c_y_x_k_block_desc.GetElementSpace(Number<max_align>{});

__shared__ Float p_in_block[in_block_space];
__shared__ Float p_wei_block[wei_block_space];

// register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];
// C++ lambdas cannot capture a raw array by value; use a pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;

// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

const Float* p_in_global_block_offset =
p_in_global + in_chwn_global_desc.Get1dIndex(
p_in_global + in_c_h_w_n_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

const Float* p_wei_global_block_offset =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0),
p_in_global_block_offset += CPerBlock * in_c_h_w_n_global_desc.GetStride(I0),
p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0),
__syncthreads())
{
#if 1

@@ -241,96 +259,140 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
#else
blockwise_batch_gemm.Run_asm
#endif
(p_wei_block + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
(p_wei_block + wei_c_y_x_k_block_desc.Get1dIndex(0, y, x, 0),
p_in_block + in_c_h_w_n_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread);
}
}
}

// output: register to global mem,
#if 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;

p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
}
#elif 1
// output: register to global mem,
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
[&](auto f_dummy) { // f_dummy does nothing but perfect forwarding; this trick
// makes the lambda generic, so it won't be compiled until
// instantiated
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");

constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;

constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / f_dummy(N1 * N2),
N1,
N2>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif

threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
})
.else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");

// output is a 10d tensor
constexpr index_t N1 = NPerBlock;

constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2 * W3),
W1,
W2,
W3,
N / N1,
N1>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

for(index_t i = 0; i < 64; ++i)
{
printf("out %f, ", p_out_thread[i]);
}
}
#endif

threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
});
}
};
||||
|
||||
@@ -1,407 +0,0 @@
|
||||
#pragma once
|
||||
#include "common.hip.hpp"
|
||||
#include "ConstantTensorDescriptor.hip.hpp"
|
||||
#include "ConstantMatrixDescriptor.hip.hpp"
|
||||
#include "blockwise_4d_tensor_op.hip.hpp"
|
||||
#include "blockwise_2d_tensor_op.hip.hpp"
|
||||
#include "threadwise_nd_tensor_op.hip.hpp"
|
||||
#include "threadwise_4d_tensor_op.hip.hpp"
|
||||
#include "blockwise_batched_gemm.hip.hpp"
|
||||
|
||||
template <index_t GridSize,
|
||||
index_t BlockSize,
|
||||
class Float,
|
||||
class InGlobalDesc,
|
||||
class WeiGlobalDesc,
|
||||
class OutGlobalDesc,
|
||||
index_t NPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t CPerBlock,
|
||||
index_t HoPerBlock,
|
||||
index_t WoPerBlock,
|
||||
index_t NPerThread,
|
||||
index_t KPerThread,
|
||||
index_t HoPerThread,
|
||||
index_t WoPerThread,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
index_t GemmNLevel0Cluster,
|
||||
index_t GemmMLevel1Cluster,
|
||||
index_t GemmNLevel1Cluster,
|
||||
index_t GemmKPerThreadLoop,
|
||||
index_t GemmDataPerReadA,
|
||||
index_t GemmDataPerReadB,
|
||||
class InBlockCopyThreadPerDims,
|
||||
index_t InBlockCopyDataPerRead,
|
||||
index_t WeiBlockCopyDataPerRead,
|
||||
index_t OutThreadCopyDataPerWrite>
|
||||
struct GridwiseConvolutionImplicitGemm_v1r1_lds_double_buffer_chwn_cyxk_khwn
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// be careful of this assertion
|
||||
static_assert(
|
||||
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
|
||||
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_chwn_global_desc = InGlobalDesc{};
|
||||
constexpr auto wei_cyxk_global_desc = WeiGlobalDesc{};
|
||||
constexpr auto out_khwn_global_desc = OutGlobalDesc{};
|
||||
|
||||
constexpr index_t C = in_chwn_global_desc.GetLength(I0);
|
||||
|
||||
constexpr index_t K = out_khwn_global_desc.GetLength(I0);
|
||||
constexpr index_t Ho = out_khwn_global_desc.GetLength(I1);
|
||||
constexpr index_t Wo = out_khwn_global_desc.GetLength(I2);
|
||||
constexpr index_t N = out_khwn_global_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t Y = wei_cyxk_global_desc.GetLength(I1);
|
||||
constexpr index_t X = wei_cyxk_global_desc.GetLength(I2);
|
||||
|
||||
constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
|
||||
constexpr index_t WiPerBlock = WoPerBlock + X - 1;
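
// The input tile is the output tile plus the filter halo: a block that
// computes HoPerBlock x WoPerBlock outputs must read
// (HoPerBlock + Y - 1) x (WoPerBlock + X - 1) inputs. Worked with
// illustrative numbers (not a config from this file):
//   HoPerBlock = 16, Y = 3  ->  HiPerBlock = 16 + 3 - 1 = 18
//   WoPerBlock = 16, X = 3  ->  WiPerBlock = 16 + 3 - 1 = 18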

// assert for LDS double buffer
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup");

constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
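
// Sanity sketch of the decomposition above: the flat block id is split in
// row-major [K, H, W, N] order (assuming GridSize == KBlockWork * HBlockWork *
// WBlockWork * NBlockWork), so it round-trips as
//   const index_t bid = ((k_block_work_id * HBlockWork + h_block_work_id) *
//                            WBlockWork + w_block_work_id) *
//                           NBlockWork + n_block_work_id;
//   // bid == get_block_1d_id()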

const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;

const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;

// flattened (2d) tensor view of gridwise weight
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * Y * X, K>{});

// tensor view of blockwise input and weight in LDS
// be careful of alignment
constexpr index_t max_align =
mod_conv::max(index_t(4), InBlockCopyDataPerRead, WeiBlockCopyDataPerRead);

constexpr auto in_chwn_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HiPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});

constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock * Y * X, KPerBlock>{}, Number<max_align>{});

constexpr auto wei_cyxk_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, Y, X, KPerBlock>{}, Number<max_align>{});

// tensor view of threadwise output in register
constexpr auto out_khwn_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

// blockwise copy
// input: format is [C, Hi, Wi, N]
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
Float,
decltype(in_chwn_global_desc),
decltype(in_chwn_block_desc),
decltype(in_chwn_block_desc.GetLengths()),
InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{};

// blockwise wei copy
// format is [CPerBlock * Y * X, KPerBlock]
const auto blockwise_wei_copy =
Blockwise2dTensorCopy3<BlockSize,
Float,
decltype(wei_ek_global_desc),
decltype(wei_ek_block_desc),
decltype(wei_ek_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};

// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
// A_matrix[C,K] is a sub-matrix of wei_block[C,Y,X,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_cyxk_block_desc.GetStride(I0)>{});

constexpr auto b_cxwn_block_mtx_desc =
make_ConstantMatrixDescriptor(Number<CPerBlock>{},
Number<WoPerBlock * NPerBlock>{},
Number<in_chwn_block_desc.GetStride(I0)>{});

constexpr auto c_kxwn_thread_mtx_desc =
make_ConstantMatrixDescriptor(Number<KPerThread>{},
Number<WoPerThread * NPerThread>{},
Number<out_khwn_thread_desc.GetStride(I0)>{});

const auto blockwise_batch_gemm =
BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
BlockSize,
decltype(a_cxk_block_mtx_desc),
decltype(b_cxwn_block_mtx_desc),
decltype(c_kxwn_thread_mtx_desc),
0,
in_chwn_block_desc.GetStride(I1),
out_khwn_thread_desc.GetStride(I1),
HoPerBlock,
GemmMPerThreadSubC,
GemmNPerThreadSubC,
GemmMLevel0Cluster,
GemmNLevel0Cluster,
GemmMLevel1Cluster,
GemmNLevel1Cluster,
GemmKPerThreadLoop,
HoPerThread,
GemmDataPerReadA,
GemmDataPerReadB>{};
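
// Reading aid: what the block as a whole accumulates, written as a plain
// reference loop (a sketch only; the real implementation is cooperative and
// each thread owns just a sub-tile of C):
//   for(index_t ho = 0; ho < HoPerBlock; ++ho)                     // batch
//       for(index_t k = 0; k < KPerBlock; ++k)                     // gemm M
//           for(index_t wn = 0; wn < WoPerBlock * NPerBlock; ++wn) // gemm N
//               for(index_t c = 0; c < CPerBlock; ++c)             // gemm K
//                   c_out[ho][k][wn] += a_wei[c][k] * b_in[ho][c][wn];
// A has batch stride 0 (the same weights serve every ho); B advances by one
// input row per batch, which is the in_chwn_block_desc.GetStride(I1) above.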

// LDS: be careful of alignment
constexpr index_t in_block_space = in_chwn_block_desc.GetElementSpace(Number<max_align>{});

constexpr index_t wei_block_space =
wei_cyxk_block_desc.GetElementSpace(Number<max_align>{});

// LDS double buffer
__shared__ Float p_in_block_double[2 * in_block_space];
__shared__ Float p_wei_block_double[2 * wei_block_space];
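
// The double buffer lets global loads for C-slice i+1 overlap the GEMM on
// slice i. A minimal sketch of the ping-pong schedule (hypothetical load() /
// compute() helpers, not this file's API; the real kernel also stages loads
// through a register clipboard so they issue before the barrier):
//   load(buf[0]);                                // prologue
//   for(index_t i = 0; i + 1 < num_slices; ++i)
//   {
//       __syncthreads();
//       load(buf[(i + 1) % 2]);                  // prefetch next slice
//       compute(buf[i % 2]);                     // consume current slice
//   }
//   __syncthreads();
//   compute(buf[(num_slices - 1) % 2]);          // epilogue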

const Float* p_in_global_block_offset =
p_in_global + in_chwn_global_desc.Get1dIndex(
0, hi_block_data_begin, wi_block_data_begin, n_block_data_begin);

const Float* p_wei_global_block_offset =
p_wei_global + wei_cyxk_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

// preload data into LDS
{
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);

blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard, p_in_block_double);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double);
}

// register
Float p_out_thread[out_khwn_thread_desc.GetElementSpace()];

// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_khwn_thread_desc, p_out_thread);

for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
c_block_data_begin += 2 * CPerBlock)
{
#pragma unroll
for(index_t iloop = 0; iloop < 2; ++iloop)
{
const bool even_loop = (iloop % 2 == 0);

Float* p_in_block_now =
even_loop ? p_in_block_double : p_in_block_double + in_block_space;
Float* p_wei_block_now =
even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

Float* p_in_block_next =
even_loop ? p_in_block_double + in_block_space : p_in_block_double;
Float* p_wei_block_next =
even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

// load next data
Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

__syncthreads();

blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);

blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);

// a series of batched GEMM
for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(
p_wei_block_now + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_now + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread);
}
}

blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_next);
blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_next);
}
}

// tail
{
// even
p_in_global_block_offset += CPerBlock * in_chwn_global_desc.GetStride(I0);
p_wei_global_block_offset += CPerBlock * wei_cyxk_global_desc.GetStride(I0);

__syncthreads();

Float p_in_register_clipboard[blockwise_in_copy.GetRegisterClipboardSize()];
Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

blockwise_in_copy.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_register_clipboard);

blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_register_clipboard);

for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(
p_wei_block_double + wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_double + in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread);
}
}

blockwise_in_copy.RunStoreRegisterClipboard(p_in_register_clipboard,
p_in_block_double + in_block_space);

blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
p_wei_block_double + wei_block_space);

// odd
__syncthreads();

for(index_t y = 0; y < Y; ++y)
{
for(index_t x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space +
wei_cyxk_block_desc.Get1dIndex(0, y, x, 0),
p_in_block_double + in_block_space +
in_chwn_block_desc.Get1dIndex(0, y, x, 0),
p_out_thread);
}
}
}

// output: register to global mem
#if 0
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

for(index_t k = 0; k < out_khwn_thread_desc.GetLength(I0); ++k)
{
for(index_t ho = 0; ho < out_khwn_thread_desc.GetLength(I1); ++ho)
{
for(index_t wo = 0; wo < out_khwn_thread_desc.GetLength(I2); ++wo)
{
for(index_t n = 0; n < out_khwn_thread_desc.GetLength(I3); ++n)
{
const index_t b = out_khwn_thread_desc.Get1dIndex(0, 0, wo, n);

const auto c_thread_mtx_distance =
blockwise_batch_gemm.GetDistanceFromBeginOfThreadMatrixC(ho, k, b);

const index_t ho_thread =
c_thread_mtx_begin.batch + c_thread_mtx_distance.batch;
const index_t k_thread = c_thread_mtx_begin.row + c_thread_mtx_distance.row;
const index_t b_thread = c_thread_mtx_begin.col + c_thread_mtx_distance.col;

const index_t wo_thread = b_thread / NPerBlock;
const index_t n_thread = b_thread % NPerBlock;

p_out_global[out_khwn_global_desc.Get1dIndex(k_block_data_begin + k_thread,
ho_block_data_begin + ho_thread,
wo_block_data_begin + wo_thread,
n_block_data_begin + n_thread)] =
p_out_thread[out_khwn_thread_desc.Get1dIndex(k, ho, wo, n)];
}
}
}
}
#elif 1
const auto c_thread_mtx_begin =
blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;

// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;

constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
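
// Why 10 dimensions: K, Wo and N are each split so that one slot is what a
// single thread owns and the rest follow the thread-cluster layout, which
// turns the per-thread output into a dense sub-box of the global tensor.
// Worked with illustrative numbers (not a config from this file):
//   GemmNPerThreadSubC = 4, NPerBlock = 16             -> N2 = 4, N1 = 4
//   GemmNLevel0Cluster * GemmNLevel1Cluster = 16       -> W2 = 16 / (16 / 4) = 4
//   WoPerBlock = 8                                     -> W1 = 8 / 4 = 2
//   KPerBlock = 64, KPerThread = 8, GemmMPerThreadSubC = 4 -> K1 = 8, K2 = 4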

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif

threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global + out_khwn_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
#endif
}
};

@@ -33,10 +33,10 @@ template <index_t GridSize,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockCopyThreadPerDims,
index_t InBlockCopyDataPerRead,
index_t WeiBlockCopyDataPerRead,
index_t OutThreadCopyDataPerWrite>
class InBlockCopyClusterLengths_CHWN,
index_t InBlockCopyDataPerRead_N,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
@@ -44,9 +44,11 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0),
"wrong!");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -101,14 +103,23 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn

// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::max(
InBlockCopyDataPerRead, WeiBlockCopyDataPerRead, GemmDataPerReadA, GemmDataPerReadB);
constexpr index_t max_align = mod_conv::max(InBlockCopyDataPerRead_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);

constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{}, Number<max_align>{});
Sequence<CPerBlock, HoPerBlock, WiPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead_N>{});

// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not met");

constexpr auto wei_c_x_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, X, KPerBlock>{}, Number<max_align>{});
Sequence<CPerBlock, X, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
@@ -116,14 +127,14 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn

// blockwise copy
// input: format is [C, Hi, Wi, N]
#if 0
#if 1
const auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
InBlockCopyDataPerRead_N>{};
#else
const auto blockwise_in_copy =
Blockwise4dTensorCopy3<BlockSize,
@@ -131,8 +142,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_block_desc),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyThreadPerDims,
InBlockCopyDataPerRead>{};
InBlockCopyClusterLengths_CHWN,
InBlockCopyDataPerRead_N>{};
#endif

// blockwise wei copy
@@ -143,7 +154,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
decltype(wei_c_x_k_global_desc),
decltype(wei_c_x_k_block_desc),
decltype(wei_c_x_k_block_desc.GetLengths()),
WeiBlockCopyDataPerRead>{};
WeiBlockCopyDataPerRead_K>{};

// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
@@ -195,7 +206,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
__shared__ Float p_wei_block[wei_block_space];

// register
Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
// C++ lambda doesn't capture array, use pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -293,46 +306,126 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
[&](auto f_dummy) { // f_dummy does nothing but perfect forwarding. Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");

constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;

constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / f_dummy(N1 * N2),
N1,
N2>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif

threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite>{});
threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
})
.else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");

// output is a 10d tensor
constexpr index_t N1 = NPerBlock;

constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2 * W3),
W1,
W2,
W3,
N / N1,
N1>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

for(index_t i = 0; i < 64; ++i)
{
printf("out %f, ", p_out_thread[i]);
}
}
#endif

threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
});
}
};

@@ -39,7 +39,7 @@ template <index_t GridSize,
index_t InBlockReorderDataPerRead_W,
index_t InBlockReorderDataPerWrite_N,
class WeiBlockCopyClusterLengths_CXK,
index_t WeiBlockCopyDataPerRead_C,
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{
@@ -106,7 +106,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
WeiBlockCopyDataPerRead_C,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);

@@ -146,7 +146,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
decltype(wei_c_x_k_block_desc),
decltype(wei_c_x_k_block_desc.GetLengths()),
WeiBlockCopyClusterLengths_CXK,
WeiBlockCopyDataPerRead_C>{};
WeiBlockCopyDataPerRead_K>{};

// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
@@ -216,6 +216,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
// set threadwise output tensor to 0
threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 0
const Float* p_in_global_block_offset =
p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);
@@ -229,7 +230,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
{
for(index_t y = 0; y < Y; ++y)
{
#if 1
blockwise_in_copy_reorder.Run(p_in_global_block_offset +
in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
p_in_block);
@@ -237,23 +237,6 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
blockwise_wei_copy.Run(p_wei_global_block_offset +
wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
p_wei_block);
#else
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

blockwise_in_copy_reorder.RunLoadRegisterClipboard(
p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, 0),
p_in_clipboard);

blockwise_wei_copy.RunLoadRegisterClipboard(
p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, 0),
p_wei_clipboard);

blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);

blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);

#endif

__syncthreads();

@@ -268,6 +251,49 @@ struct GridwiseConvolutionImplicitGemm_v1r2_nchw_cyxk_khwn
__syncthreads();
}
}
#else
for(index_t y = 0; y < Y; ++y)
{
const Float* p_in_global_block_offset =
p_in_global +
in_n_c_h_w_global_desc.Get1dIndex(
n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin);

const Float* p_wei_global_block_offset =
p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, 0, k_block_data_begin);

for(index_t c_block_data_begin = 0; c_block_data_begin < C;
c_block_data_begin += CPerBlock,
p_in_global_block_offset +=
CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
p_wei_global_block_offset +=
CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
{
Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
p_in_clipboard);
blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
p_wei_clipboard);

blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);

__syncthreads();

for(index_t x = 0; x < X; ++x)
{
blockwise_batch_gemm.Run(p_wei_block + wei_c_x_k_block_desc.Get1dIndex(0, x, 0),
p_in_block +
in_c_h_w_n_block_desc.Get1dIndex(0, 0, x, 0),
p_out_thread);
}

__syncthreads();
}
}
#endif

// output: register to global mem
const auto c_thread_mtx_begin =

@@ -43,9 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0),
"wrong!");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -66,9 +68,6 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

constexpr index_t HiPerBlock = HoPerBlock + Y - 1;
constexpr index_t WiPerBlock = WoPerBlock + X - 1;

// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
@@ -106,10 +105,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
GemmDataPerReadB);

constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead_N>{});

// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not met");

constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
@@ -177,6 +183,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
GemmDataPerReadB>{};

// LDS: be careful of alignment
// TODO: need to properly implement tensor descriptor with alignment
constexpr index_t in_block_space =
in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});
@@ -185,7 +192,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
__shared__ Float p_wei_block[wei_block_space];

// register
Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
// C++ lambda doesn't capture array, use pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -276,46 +285,126 @@ struct GridwiseConvolutionImplicitGemm_v1r3_chwn_cyxk_khwn
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
[&](auto f_dummy) { // f_dummy does nothing but perfect forwarding. Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");

constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;

constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / f_dummy(N1 * N2),
N1,
N2>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif

threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
})
.else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");

// output is a 10d tensor
constexpr index_t N1 = NPerBlock;

constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2 * W3),
W1,
W2,
W3,
N / N1,
N1>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

for(index_t i = 0; i < 64; ++i)
{
printf("out %f, ", p_out_thread[i]);
}
}
#endif

threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
});
}
};

@@ -43,9 +43,11 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(
NPerThread <= NPerBlock && NPerBlock % NPerThread == 0,
"wrong! should satisfy: NPerThread <= NPerBlock && NPerBlock % NPerThread == 0");
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0),
"wrong!");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
@@ -109,10 +111,17 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
GemmDataPerReadB);

constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{}, Number<max_align>{});
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockCopyDataPerRead_N>{});

// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not met");

constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{}, Number<max_align>{});
Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
@@ -199,7 +208,9 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
__shared__ Float p_wei_block_double[2 * wei_block_space];

// register
Float p_out_thread[out_k_h_w_n_thread_desc.GetElementSpace()];
// C++ lambda doesn't capture array, use pointer instead
Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
Float* const p_out_thread = p_out_thread_data;

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
@@ -336,46 +347,126 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
const index_t k_thread_data_begin = c_thread_mtx_begin.row;
const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
const index_t n_thread_data_begin =
c_thread_mtx_begin.col - NPerBlock * wo_thread_data_begin;
const index_t n_thread_data_begin = c_thread_mtx_begin.col % NPerBlock;

// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;
static_if<GemmNPerThreadSubC <= NPerBlock>{}(
[&](auto f_dummy) { // f_dummy does nothing but perfect forwarding. Using this trick to
// make this lambda a generic lambda, so it won't be compiled until
// instantiated
static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0),
"wrong!");

constexpr index_t W2 =
(GemmNLevel0Cluster * GemmNLevel1Cluster) / (NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;
// output is a 10d tensor
constexpr index_t N2 = GemmNPerThreadSubC;
constexpr index_t N1 = NPerBlock / N2;

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;
constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
f_dummy(NPerBlock / GemmNPerThreadSubC);
constexpr index_t W1 = WoPerBlock / W2;

constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2, N / (N1 * N2), N1, N2>{});
constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2),
W1,
W2,
N / f_dummy(N1 * N2),
N1,
N2>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_khwn_thread_desc, "out_khwn_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_khwn_global_desc, "out_khwn_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
}
#endif

threadwise_10d_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global + out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
})
.else_([&](auto f_dummy) {
static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0,
"wrong!");

// output is a 10d tensor
constexpr index_t N1 = NPerBlock;

constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

constexpr index_t K2 = GemmMPerThreadSubC;
constexpr index_t K1 = KPerBlock / KPerThread;

constexpr auto out_10d_global_desc =
make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
K1,
K2,
Ho,
Wo / (W1 * W2 * W3),
W1,
W2,
W3,
N / N1,
N1>{});

constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
{
print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
"out_k_h_w_n_thread_desc");
print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
"out_k_h_w_n_global_desc");
print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

for(index_t i = 0; i < 64; ++i)
{
printf("out %f, ", p_out_thread[i]);
}
}
#endif

threadwise_nd_tensor_copy(out_10d_thread_desc,
p_out_thread,
out_10d_global_desc,
p_out_global +
out_k_h_w_n_global_desc.Get1dIndex(
k_block_data_begin + k_thread_data_begin,
ho_block_data_begin + ho_thread_data_begin,
wo_block_data_begin + wo_thread_data_begin,
n_block_data_begin + n_thread_data_begin),
out_10d_thread_desc.GetLengths(),
Number<OutThreadCopyDataPerWrite_N>{});
});
}
};

@@ -0,0 +1,472 @@
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

template <index_t GridSize,
index_t BlockSize,
class Float,
class InGlobalDesc,
class WeiGlobalDesc,
class OutGlobalDesc,
index_t NPerBlock,
index_t KPerBlock,
index_t CPerBlock,
index_t HoPerBlock,
index_t WoPerBlock,
index_t NPerThread,
index_t KPerThread,
index_t HoPerThread,
index_t WoPerThread,
index_t GemmMPerThreadSubC,
index_t GemmNPerThreadSubC,
index_t GemmMLevel0Cluster,
index_t GemmNLevel0Cluster,
index_t GemmMLevel1Cluster,
index_t GemmNLevel1Cluster,
index_t GemmKPerThreadLoop,
index_t GemmDataPerReadA,
index_t GemmDataPerReadB,
class InBlockReorderSrcSubLengths_NCHW,
class InBlockReorderSrcClusterLengths_NCHW,
class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
index_t InBlockReorderDataPerRead_W,
index_t InBlockReorderDataPerWrite_N,
class WeiBlockCopyClusterLengths_CK, // not used
index_t WeiBlockCopyDataPerRead_K,
index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
{
__device__ void Run(const Float* const __restrict__ p_in_global,
const Float* const __restrict__ p_wei_global,
Float* const __restrict__ p_out_global) const
{
// be careful of this assertion
static_assert(NPerBlock % NPerThread == 0 && (GemmNPerThreadSubC <= NPerBlock &&
NPerBlock % GemmNPerThreadSubC == 0) ||
(GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
GemmNPerThreadSubC % NPerThread == 0),
"wrong!");

constexpr auto I0 = Number<0>{};
constexpr auto I1 = Number<1>{};
constexpr auto I2 = Number<2>{};
constexpr auto I3 = Number<3>{};

constexpr auto in_n_c_h_w_global_desc = InGlobalDesc{};
constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

constexpr index_t K = out_k_h_w_n_global_desc.GetLength(I0);
constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
constexpr index_t N = out_k_h_w_n_global_desc.GetLength(I3);

constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

// assert for LDS double buffer
static_assert(C % (2 * CPerBlock) == 0, "C cannot be evenly divided");

// divide block work: [K, Ho, Wo, N]
static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
"wrong! cannot evenly divide work for workgroup");

constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
itmp -= h_block_work_id * (WBlockWork * NBlockWork);
const index_t w_block_work_id = itmp / NBlockWork;
const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

const index_t k_block_data_begin = k_block_work_id * KPerBlock;
const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
const index_t n_block_data_begin = n_block_work_id * NPerBlock;

const index_t hi_block_data_begin = ho_block_data_begin;
const index_t wi_block_data_begin = wo_block_data_begin;

// global tensor view
constexpr auto wei_c_k_global_desc =
make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
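
// This 2d [C, K] view aliases the 4d CYXK weight: giving dim C the full
// Y * X * K stride means Get1dIndex(c, k) lands on element (c, y, x, k) once
// the y/x offset is folded into the base pointer, as the copy below does:
//   p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0)
//       + wei_c_k_global_desc.Get1dIndex(c, k)   // == element (c, y, x, k)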

// LDS tensor view
// be careful of alignment
constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
WeiBlockCopyDataPerRead_K,
GemmDataPerReadA,
GemmDataPerReadB);

constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
Number<InBlockReorderDataPerWrite_N>{});

// this check is ad-hoc
// TODO: need to properly implement tensor descriptor with alignment
static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
"GemmDataPerReadB alignment requirement is not met");

constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
Sequence<CPerBlock, KPerBlock>{},
Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

// tensor view of threadwise output in register
constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

// blockwise copy
// input: reorder from [N, C, Hi, Wi] to [C, Hi, Wi, N]
constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};
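
// One way to read the map: entry i names the source (NCHW) dimension that
// destination (CHWN) dimension i comes from, i.e. C<-1, H<-2, W<-3, N<-0.
// Applied to a coordinate (plain-array sketch of that reading):
//   constexpr index_t map[4] = {1, 2, 3, 0}; // same contents as map_chwn2nchw
//   const index_t nchw[4] = {n, c, h, w};
//   index_t chwn[4];
//   for(index_t i = 0; i < 4; ++i)
//       chwn[i] = nchw[map[i]]; // yields {c, h, w, n}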
|
||||
|
||||
const auto blockwise_in_copy_reorder =
|
||||
Blockwise4dTensorCopyReorder3<BlockSize,
|
||||
Float,
|
||||
decltype(in_n_c_h_w_global_desc),
|
||||
decltype(in_c_h_w_n_block_desc),
|
||||
Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
|
||||
InBlockReorderSrcSubLengths_NCHW,
|
||||
InBlockReorderSrcClusterLengths_NCHW,
|
||||
decltype(map_chwn2nchw),
|
||||
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
|
||||
InBlockReorderDataPerRead_W,
|
||||
InBlockReorderDataPerWrite_N>{};
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock, KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_c_k_global_desc),
|
||||
decltype(wei_c_k_block_desc),
|
||||
decltype(wei_c_k_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead_K>{};
|
||||
|
||||
        // a series of blockwise batched GEMM
        //   C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix are saved in LDS, C_matrix is saved in registers
        //   A_matrix[C,K]    is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};
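
        // note: the "batch" dimension of this batched GEMM is Ho, i.e. each block
        // runs HoPerBlock independent C[K, Wo*N] += A[C, K]^T * B[C, Wo*N] sub-GEMMs,
        // one per output row of the tile (inferred from the HoPerBlock / HoPerThread
        // template arguments above)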

        // LDS: be careful of alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        // LDS double buffer
        __shared__ Float p_in_block_double[2 * in_block_space];
        __shared__ Float p_wei_block_double[2 * wei_block_space];
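
        // the buffers are twice the tile size: even iterations of the C-loop compute
        // out of one half while the next slab of data is loaded into the other half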

        // register
        // a C++ lambda cannot capture an array by value, so use a pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");

            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global +
                    in_n_c_h_w_global_desc.Get1dIndex(
                        n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);

                const Float* p_wei_global_block_offset =
                    p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);

                // LDS double buffer: preload data into LDS
                {
                    Float p_in_register_clipboard[blockwise_in_copy_reorder
                                                      .GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_register_clipboard,
                                                                        p_in_block_double);
                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                 p_wei_block_double);
                }

                // LDS double buffer: main body
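                // each pass of the outer loop consumes 2 * CPerBlock channels: the
                // inner iloop ping-pongs between the two halves of the LDS buffers,
                // overlapping the global load of the next slab with the GEMM on the
                // current one; the final two slabs are handled by the tail below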
                for(index_t c_block_data_begin = 0; c_block_data_begin + 2 * CPerBlock < C;
                    c_block_data_begin += 2 * CPerBlock)
                {
#pragma unroll
                    for(index_t iloop = 0; iloop < 2; ++iloop)
                    {
                        const bool even_loop = (iloop % 2 == 0);

                        Float* p_in_block_now =
                            even_loop ? p_in_block_double : p_in_block_double + in_block_space;
                        Float* p_wei_block_now =
                            even_loop ? p_wei_block_double : p_wei_block_double + wei_block_space;

                        Float* p_in_block_next =
                            even_loop ? p_in_block_double + in_block_space : p_in_block_double;
                        Float* p_wei_block_next =
                            even_loop ? p_wei_block_double + wei_block_space : p_wei_block_double;

                        Float p_in_register_clipboard[blockwise_in_copy_reorder
                                                          .GetRegisterClipboardSize()];
                        Float p_wei_register_clipboard[blockwise_wei_copy
                                                           .GetRegisterClipboardSize()];

                        p_in_global_block_offset +=
                            CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                        p_wei_global_block_offset +=
                            CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                        __syncthreads();

                        // LDS double buffer: load next data from device mem
                        blockwise_in_copy_reorder.RunLoadRegisterClipboard(
                            p_in_global_block_offset, p_in_register_clipboard);
                        blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                    p_wei_register_clipboard);

                        // LDS double buffer: GEMM on current data
                        blockwise_batch_gemm.Run(p_wei_block_now, p_in_block_now, p_out_thread);

                        // LDS double buffer: store next data to LDS
                        blockwise_in_copy_reorder.RunStoreRegisterClipboard(
                            p_in_register_clipboard, p_in_block_next);
                        blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_register_clipboard,
                                                                     p_wei_block_next);
                    }
                }

                // LDS double buffer: tail
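                // the tail drains the last two CPerBlock slabs (the loop above appears
                // to assume C / CPerBlock is even): the even slab's GEMM overlaps the
                // load of the odd slab, then one final GEMM runs on the second half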
                {
                    Float p_in_register_clipboard[blockwise_in_copy_reorder
                                                      .GetRegisterClipboardSize()];
                    Float p_wei_register_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    // even iteration
                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1);
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0);

                    __syncthreads();

                    // LDS double buffer: load next data from device mem
                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_register_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_register_clipboard);

                    // LDS double buffer: GEMM on current data
                    blockwise_batch_gemm.Run(p_wei_block_double, p_in_block_double, p_out_thread);

                    // LDS double buffer: store next data to LDS
                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(
                        p_in_register_clipboard, p_in_block_double + in_block_space);
                    blockwise_wei_copy.RunStoreRegisterClipboard(
                        p_wei_register_clipboard, p_wei_block_double + wei_block_space);

                    // odd iteration
                    __syncthreads();

                    // LDS double buffer: GEMM on current data
                    blockwise_batch_gemm.Run(p_wei_block_double + wei_block_space,
                                             p_in_block_double + in_block_space,
                                             p_out_thread);
                }
            }
        }

        // output: copy result from registers to global memory
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
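
        // the GEMM C-matrix column enumerates (Wo, N) pairs: dividing by NPerBlock
        // recovers the Wo offset and the remainder is the N offset within the block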

        static_if<GemmNPerThreadSubC <= NPerBlock>{}(
            [&](auto f_dummy) { // f_dummy does nothing but perfect forwarding; this trick
                                // makes the lambda generic, so it is not compiled until
                                // instantiated
                static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
                               NPerBlock % GemmNPerThreadSubC == 0),
                              "wrong!");

                // output is a 10d tensor
                constexpr index_t N2 = GemmNPerThreadSubC;
                constexpr index_t N1 = NPerBlock / N2;

                constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
                                       f_dummy(NPerBlock / GemmNPerThreadSubC);
                constexpr index_t W1 = WoPerBlock / W2;

                constexpr index_t K2 = GemmMPerThreadSubC;
                constexpr index_t K1 = KPerBlock / KPerThread;
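
                // K is split into K/(K1*K2) x K1 x K2, Wo into Wo/(W1*W2) x W1 x W2,
                // and N into N/(N1*N2) x N1 x N2, giving the 10 dimensions of the
                // global output view constructed below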

                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2),
                                                           W1,
                                                           W2,
                                                           N / f_dummy(N1 * N2),
                                                           N1,
                                                           N2>{});

                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                    print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                                   "out_k_h_w_n_thread_desc");
                    print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                    print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                                   "out_k_h_w_n_global_desc");
                    print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
                }
#endif

                threadwise_nd_tensor_copy(out_10d_thread_desc,
                                          p_out_thread,
                                          out_10d_global_desc,
                                          p_out_global +
                                              out_k_h_w_n_global_desc.Get1dIndex(
                                                  k_block_data_begin + k_thread_data_begin,
                                                  ho_block_data_begin + ho_thread_data_begin,
                                                  wo_block_data_begin + wo_thread_data_begin,
                                                  n_block_data_begin + n_thread_data_begin),
                                          out_10d_thread_desc.GetLengths(),
                                          Number<OutThreadCopyDataPerWrite_N>{});
            })
            .else_([&](auto f_dummy) {
                static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock &&
                                  NPerThread == NPerBlock &&
                                  GemmNPerThreadSubC % NPerThread == 0,
                              "wrong!");

                // output is a 10d tensor
                constexpr index_t N1 = NPerBlock;

                constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
                constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
                constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

                constexpr index_t K2 = GemmMPerThreadSubC;
                constexpr index_t K1 = KPerBlock / KPerThread;

                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2 * W3),
                                                           W1,
                                                           W2,
                                                           W3,
                                                           N / N1,
                                                           N1>{});

                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                    print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                                   "out_k_h_w_n_thread_desc");
                    print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                    print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                                   "out_k_h_w_n_global_desc");
                    print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                    for(index_t i = 0; i < 64; ++i)
                    {
                        printf("out %f, ", p_out_thread[i]);
                    }
                }
#endif

                threadwise_nd_tensor_copy(out_10d_thread_desc,
                                          p_out_thread,
                                          out_10d_global_desc,
                                          p_out_global +
                                              out_k_h_w_n_global_desc.Get1dIndex(
                                                  k_block_data_begin + k_thread_data_begin,
                                                  ho_block_data_begin + ho_thread_data_begin,
                                                  wo_block_data_begin + wo_thread_data_begin,
                                                  n_block_data_begin + n_thread_data_begin),
                                          out_10d_thread_desc.GetLengths(),
                                          Number<OutThreadCopyDataPerWrite_N>{});
            });
    }
};

@@ -0,0 +1,452 @@
#pragma once
#include "common.hip.hpp"
#include "ConstantTensorDescriptor.hip.hpp"
#include "ConstantMatrixDescriptor.hip.hpp"
#include "blockwise_2d_tensor_op.hip.hpp"
#include "blockwise_4d_tensor_op.hip.hpp"
#include "threadwise_nd_tensor_op.hip.hpp"
#include "threadwise_4d_tensor_op.hip.hpp"
#include "blockwise_batched_gemm.hip.hpp"

template <index_t GridSize,
          index_t BlockSize,
          class Float,
          class InGlobalDesc,
          class WeiGlobalDesc,
          class OutGlobalDesc,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t CPerBlock,
          index_t HoPerBlock,
          index_t WoPerBlock,
          index_t NPerThread,
          index_t KPerThread,
          index_t HoPerThread,
          index_t WoPerThread,
          index_t GemmMPerThreadSubC,
          index_t GemmNPerThreadSubC,
          index_t GemmMLevel0Cluster,
          index_t GemmNLevel0Cluster,
          index_t GemmMLevel1Cluster,
          index_t GemmNLevel1Cluster,
          index_t GemmKPerThreadLoop,
          index_t GemmDataPerReadA,
          index_t GemmDataPerReadB,
          class InBlockReorderSrcSubLengths_NCHW,
          class InBlockReorderSrcClusterLengths_NCHW,
          class InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
          index_t InBlockReorderDataPerRead_W,
          index_t InBlockReorderDataPerWrite_N,
          class WeiBlockCopyClusterLengths_CK, // not used
          index_t WeiBlockCopyDataPerRead_K,
          index_t OutThreadCopyDataPerWrite_N>
struct GridwiseConvolutionImplicitGemm_v1r3_nchw_cyxk_khwn
{
    __device__ void Run(const Float* const __restrict__ p_in_global,
                        const Float* const __restrict__ p_wei_global,
                        Float* const __restrict__ p_out_global) const
    {
        // be careful of this assertion
        static_assert((NPerBlock % NPerThread == 0 &&
                       (GemmNPerThreadSubC <= NPerBlock &&
                        NPerBlock % GemmNPerThreadSubC == 0)) ||
                          (GemmNPerThreadSubC >= NPerBlock && NPerThread == NPerBlock &&
                           GemmNPerThreadSubC % NPerThread == 0),
                      "wrong!");

        constexpr auto I0 = Number<0>{};
        constexpr auto I1 = Number<1>{};
        constexpr auto I2 = Number<2>{};
        constexpr auto I3 = Number<3>{};

        constexpr auto in_n_c_h_w_global_desc  = InGlobalDesc{};
        constexpr auto wei_c_y_x_k_global_desc = WeiGlobalDesc{};
        constexpr auto out_k_h_w_n_global_desc = OutGlobalDesc{};

        constexpr index_t C = in_n_c_h_w_global_desc.GetLength(I1);

        constexpr index_t K  = out_k_h_w_n_global_desc.GetLength(I0);
        constexpr index_t Ho = out_k_h_w_n_global_desc.GetLength(I1);
        constexpr index_t Wo = out_k_h_w_n_global_desc.GetLength(I2);
        constexpr index_t N  = out_k_h_w_n_global_desc.GetLength(I3);

        constexpr index_t Y = wei_c_y_x_k_global_desc.GetLength(I1);
        constexpr index_t X = wei_c_y_x_k_global_desc.GetLength(I2);

        // divide block work: [K, Ho, Wo, N]
        static_assert(N % NPerBlock == 0 && K % KPerBlock == 0 && C % CPerBlock == 0 &&
                          Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                      "wrong! cannot evenly divide work among workgroups");

        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;

        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
        const index_t w_block_work_id = itmp / NBlockWork;
        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;

        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;

        const index_t hi_block_data_begin = ho_block_data_begin;
        const index_t wi_block_data_begin = wo_block_data_begin;

        // global tensor view
        constexpr auto wei_c_k_global_desc =
            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});

        // LDS tensor view
        // be careful of alignment
        constexpr index_t max_align = mod_conv::max(InBlockReorderDataPerWrite_N,
                                                    WeiBlockCopyDataPerRead_K,
                                                    GemmDataPerReadA,
                                                    GemmDataPerReadB);

        constexpr auto in_c_h_w_n_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, HoPerBlock, WoPerBlock, NPerBlock>{},
            Number<InBlockReorderDataPerWrite_N>{});

        // this check is ad-hoc
        // TODO: need to properly implement tensor descriptor with alignment
        static_assert(in_c_h_w_n_block_desc.GetStride(I1) % GemmDataPerReadB == 0,
                      "GemmDataPerReadB alignment requirement is not met");

        constexpr auto wei_c_k_block_desc = make_ConstantTensorDescriptor_aligned(
            Sequence<CPerBlock, KPerBlock>{},
            Number<mod_conv::max(WeiBlockCopyDataPerRead_K, GemmDataPerReadA)>{});

        // tensor view of threadwise output in register
        constexpr auto out_k_h_w_n_thread_desc = make_ConstantTensorDescriptor(
            Sequence<KPerThread, HoPerThread, WoPerThread, NPerThread>{});

        // blockwise copy
        // input: reorder from [N, C, Hi, Wi] in global memory to [C, Hi, Wi, N] in LDS
        constexpr auto map_chwn2nchw = Sequence<1, 2, 3, 0>{};

        const auto blockwise_in_copy_reorder =
            Blockwise4dTensorCopyReorder3<BlockSize,
                                          Float,
                                          decltype(in_n_c_h_w_global_desc),
                                          decltype(in_c_h_w_n_block_desc),
                                          Sequence<NPerBlock, CPerBlock, HoPerBlock, WoPerBlock>,
                                          InBlockReorderSrcSubLengths_NCHW,
                                          InBlockReorderSrcClusterLengths_NCHW,
                                          decltype(map_chwn2nchw),
                                          InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW,
                                          InBlockReorderDataPerRead_W,
                                          InBlockReorderDataPerWrite_N>{};

        // blockwise wei copy
        // format is [CPerBlock, KPerBlock]
        const auto blockwise_wei_copy =
            Blockwise2dTensorCopy3<BlockSize,
                                   Float,
                                   decltype(wei_c_k_global_desc),
                                   decltype(wei_c_k_block_desc),
                                   decltype(wei_c_k_block_desc.GetLengths()),
                                   WeiBlockCopyDataPerRead_K>{};

        // a series of blockwise batched GEMM
        //   C_matrix += transpose(A_matrix) * B_matrix
        //   A_matrix and B_matrix are saved in LDS, C_matrix is saved in registers
        //   A_matrix[C,K]    is a sub-matrix of wei_block[C,K]
        //   B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
        //   C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
        constexpr auto a_c_k_block_mtx_desc = make_ConstantMatrixDescriptor(
            Number<CPerBlock>{}, Number<KPerBlock>{}, Number<wei_c_k_block_desc.GetStride(I0)>{});

        constexpr auto b_c_wn_block_mtx_desc =
            make_ConstantMatrixDescriptor(Number<CPerBlock>{},
                                          Number<WoPerBlock * NPerBlock>{},
                                          Number<in_c_h_w_n_block_desc.GetStride(I0)>{});

        constexpr auto c_k_wn_thread_mtx_desc =
            make_ConstantMatrixDescriptor(Number<KPerThread>{},
                                          Number<WoPerThread * NPerThread>{},
                                          Number<out_k_h_w_n_thread_desc.GetStride(I0)>{});

        const auto blockwise_batch_gemm =
            BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2<
                BlockSize,
                decltype(a_c_k_block_mtx_desc),
                decltype(b_c_wn_block_mtx_desc),
                decltype(c_k_wn_thread_mtx_desc),
                0,
                in_c_h_w_n_block_desc.GetStride(I1),
                out_k_h_w_n_thread_desc.GetStride(I1),
                HoPerBlock,
                GemmMPerThreadSubC,
                GemmNPerThreadSubC,
                GemmMLevel0Cluster,
                GemmNLevel0Cluster,
                GemmMLevel1Cluster,
                GemmNLevel1Cluster,
                GemmKPerThreadLoop,
                HoPerThread,
                GemmDataPerReadA,
                GemmDataPerReadB>{};

        // LDS: be careful of alignment
        constexpr index_t in_block_space =
            in_c_h_w_n_block_desc.GetElementSpace(Number<max_align>{});
        constexpr index_t wei_block_space = wei_c_k_block_desc.GetElementSpace(Number<max_align>{});

        __shared__ Float p_in_block[in_block_space];
        __shared__ Float p_wei_block[wei_block_space];
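
        // single LDS buffer in this variant (no double buffering): each (c, y, x)
        // step loads into LDS, synchronizes, runs the GEMM, and synchronizes again
        // before the buffers are overwritten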

        // register
        // a C++ lambda cannot capture an array by value, so use a pointer instead
        Float p_out_thread_data[out_k_h_w_n_thread_desc.GetElementSpace()];
        Float* const p_out_thread = p_out_thread_data;

#if 0
        if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
        {
            print_ConstantTensorDescriptor(in_c_h_w_n_global_desc, "in_c_h_w_n_global_desc");
            print_ConstantTensorDescriptor(wei_c_y_x_k_global_desc, "wei_c_y_x_k_global_desc");

            print_ConstantTensorDescriptor(in_c_h_w_n_block_desc, "in_c_h_w_n_block_desc");
            print_ConstantTensorDescriptor(wei_c_k_block_desc, "wei_c_k_block_desc");

            printf("in_block_space %u, wei_block_space %u\n", in_block_space, wei_block_space);
        }
#endif

        // set threadwise output tensor to 0
        threadwise_4d_tensor_set_zero(out_k_h_w_n_thread_desc, p_out_thread);

#if 1
        const Float* p_in_global_block_offset =
            p_in_global + in_n_c_h_w_global_desc.Get1dIndex(
                              n_block_data_begin, 0, hi_block_data_begin, wi_block_data_begin);

        const Float* p_wei_global_block_offset =
            p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin);

        for(index_t c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
            p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
            p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
        {
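            // loop order: C slabs outermost, filter taps (y, x) inside; the global
            // offsets advance by CPerBlock strides in the loop header above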
            for(index_t y = 0; y < Y; ++y)
            {
                for(index_t x = 0; x < X; ++x)
                {
#if 1
                    blockwise_in_copy_reorder.Run(p_in_global_block_offset +
                                                      in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x),
                                                  p_in_block);

                    blockwise_wei_copy.Run(p_wei_global_block_offset +
                                               wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0),
                                           p_wei_block);
#else
                    Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                    Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(
                        p_in_global_block_offset + in_n_c_h_w_global_desc.Get1dIndex(0, 0, y, x),
                        p_in_clipboard);

                    blockwise_wei_copy.RunLoadRegisterClipboard(
                        p_wei_global_block_offset + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, 0),
                        p_wei_clipboard);

                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);

                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);

#endif

                    __syncthreads();

                    blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#else
        for(index_t y = 0; y < Y; ++y)
        {
            for(index_t x = 0; x < X; ++x)
            {
                const Float* p_in_global_block_offset =
                    p_in_global +
                    in_n_c_h_w_global_desc.Get1dIndex(
                        n_block_data_begin, 0, hi_block_data_begin + y, wi_block_data_begin + x);

                const Float* p_wei_global_block_offset =
                    p_wei_global + wei_c_y_x_k_global_desc.Get1dIndex(0, y, x, k_block_data_begin);

                for(index_t c_block_data_begin = 0; c_block_data_begin < C;
                    c_block_data_begin += CPerBlock,
                    p_in_global_block_offset += CPerBlock * in_n_c_h_w_global_desc.GetStride(I1),
                    p_wei_global_block_offset += CPerBlock * wei_c_y_x_k_global_desc.GetStride(I0))
                {
#if 0
                    blockwise_in_copy_reorder.Run(p_in_global_block_offset, p_in_block);

                    blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
#else
                    Float p_in_clipboard[blockwise_in_copy_reorder.GetRegisterClipboardSize()];
                    Float p_wei_clipboard[blockwise_wei_copy.GetRegisterClipboardSize()];

                    blockwise_in_copy_reorder.RunLoadRegisterClipboard(p_in_global_block_offset,
                                                                       p_in_clipboard);
                    blockwise_wei_copy.RunLoadRegisterClipboard(p_wei_global_block_offset,
                                                                p_wei_clipboard);

                    blockwise_wei_copy.RunStoreRegisterClipboard(p_wei_clipboard, p_wei_block);
                    blockwise_in_copy_reorder.RunStoreRegisterClipboard(p_in_clipboard, p_in_block);
#endif

                    __syncthreads();

                    blockwise_batch_gemm.Run(p_wei_block, p_in_block, p_out_thread);

                    __syncthreads();
                }
            }
        }
#endif

        // output: copy result from registers to global memory
        const auto c_thread_mtx_begin =
            blockwise_batch_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());

        const index_t k_thread_data_begin  = c_thread_mtx_begin.row;
        const index_t ho_thread_data_begin = c_thread_mtx_begin.batch;
        const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
        const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;

        static_if<GemmNPerThreadSubC <= NPerBlock>{}(
            [&](auto f_dummy) { // f_dummy does nothing but perfect forwarding; this trick
                                // makes the lambda generic, so it is not compiled until
                                // instantiated
                static_assert((f_dummy(GemmNPerThreadSubC) <= NPerBlock &&
                               NPerBlock % GemmNPerThreadSubC == 0),
                              "wrong!");

                // output is a 10d tensor
                constexpr index_t N2 = GemmNPerThreadSubC;
                constexpr index_t N1 = NPerBlock / N2;

                constexpr index_t W2 = (GemmNLevel0Cluster * GemmNLevel1Cluster) /
                                       f_dummy(NPerBlock / GemmNPerThreadSubC);
                constexpr index_t W1 = WoPerBlock / W2;

                constexpr index_t K2 = GemmMPerThreadSubC;
                constexpr index_t K1 = KPerBlock / KPerThread;

                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2),
                                                           W1,
                                                           W2,
                                                           N / f_dummy(N1 * N2),
                                                           N1,
                                                           N2>{});

                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                    print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                                   "out_k_h_w_n_thread_desc");
                    print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                    print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                                   "out_k_h_w_n_global_desc");
                    print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
                }
#endif

                threadwise_nd_tensor_copy(out_10d_thread_desc,
                                          p_out_thread,
                                          out_10d_global_desc,
                                          p_out_global +
                                              out_k_h_w_n_global_desc.Get1dIndex(
                                                  k_block_data_begin + k_thread_data_begin,
                                                  ho_block_data_begin + ho_thread_data_begin,
                                                  wo_block_data_begin + wo_thread_data_begin,
                                                  n_block_data_begin + n_thread_data_begin),
                                          out_10d_thread_desc.GetLengths(),
                                          Number<OutThreadCopyDataPerWrite_N>{});
            })
            .else_([&](auto f_dummy) {
                static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock &&
                                  NPerThread == NPerBlock &&
                                  GemmNPerThreadSubC % NPerThread == 0,
                              "wrong!");

                // output is a 10d tensor
                constexpr index_t N1 = NPerBlock;

                constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
                constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
                constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);

                constexpr index_t K2 = GemmMPerThreadSubC;
                constexpr index_t K1 = KPerBlock / KPerThread;

                constexpr auto out_10d_global_desc =
                    make_ConstantTensorDescriptor(Sequence<K / (K1 * K2),
                                                           K1,
                                                           K2,
                                                           Ho,
                                                           Wo / (W1 * W2 * W3),
                                                           W1,
                                                           W2,
                                                           W3,
                                                           N / N1,
                                                           N1>{});

                constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
                    Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});

#if 0
                if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
                {
                    print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
                                                   "out_k_h_w_n_thread_desc");
                    print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");

                    print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
                                                   "out_k_h_w_n_global_desc");
                    print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");

                    for(index_t i = 0; i < 64; ++i)
                    {
                        printf("out %f, ", p_out_thread[i]);
                    }
                }
#endif

                threadwise_nd_tensor_copy(out_10d_thread_desc,
                                          p_out_thread,
                                          out_10d_global_desc,
                                          p_out_global +
                                              out_k_h_w_n_global_desc.Get1dIndex(
                                                  k_block_data_begin + k_thread_data_begin,
                                                  ho_block_data_begin + ho_thread_data_begin,
                                                  wo_block_data_begin + wo_thread_data_begin,
                                                  n_block_data_begin + n_thread_data_begin),
                                          out_10d_thread_desc.GetLengths(),
                                          Number<OutThreadCopyDataPerWrite_N>{});
            });
    }
};

@@ -88,6 +88,7 @@ threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}

#if 0 // replaced by threadwise_nd_tensor_copy
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_2d_tensor_copy(
    SrcDesc, Float* const __restrict__ p_src, DstDesc, Float* __restrict__ p_dst, SrcOpLengths)
@@ -97,6 +98,7 @@ __device__ void threadwise_2d_tensor_copy(
    threadwise_2d_tensor_copy_reorder_by_get_dst_from_src(
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, dst_from_src_reorder);
}
#endif

template <class Float, class Desc, class IDim, class NShift>
__device__ void threadwise_2d_tensor_shift_down(Desc, Float* __restrict__ p, IDim, NShift)

@@ -139,6 +139,7 @@ __device__ void threadwise_4d_tensor_copy_reorder_given_dst2src(SrcDesc,
        SrcDesc{}, p_src, DstDesc{}, p_dst, SrcOpLengths{}, MapDst2Src{}, f_copy);
}

#if 0 // replaced by threadwise_nd_tensor_copy
template <class SrcData, class DstData, class SrcDesc, class DstDesc, class SrcOpLengths>
__device__ void threadwise_4d_tensor_copy(
    SrcDesc, const SrcData* __restrict__ p_src, DstDesc, DstData* __restrict__ p_dst, SrcOpLengths)
@@ -210,6 +211,7 @@ __device__ void threadwise_4d_tensor_copy_v2(SrcDesc,
        }
    }
}
#endif

template <class SrcData,
          class DstData,

@@ -3,7 +3,7 @@

// need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_6d_tensor_copy(SrcDesc,
__device__ void threadwise_nd_tensor_copy(SrcDesc,
                                          const Float* __restrict__ p_src,
                                          DstDesc,
                                          Float* __restrict__ p_dst,
@@ -12,268 +12,53 @@ __device__ void threadwise_6d_tensor_copy(SrcDesc,
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    static_assert(SrcDesc{}.GetDimension() == 6 && DstDesc{}.GetDimension() == 6 &&
                      SrcOpLengths::nDim == 6,
                  "wrong! should be 6 dimension");
    constexpr index_t nDim = SrcOpLengths::GetSize();

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    static_assert(SrcDesc{}.GetDimension() == nDim && DstDesc{}.GetDimension() == nDim,
                  "wrong! dimension not consistent");

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I5) == 1 && DstDesc{}.GetStride(I5) == 1,
                  "wrong! only support stride5 == 1!\n");
#if 0
    if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
    {
        print_ConstantTensorDescriptor(src_desc, "src_desc");
        print_ConstantTensorDescriptor(dst_desc, "dst_desc");
        print_ConstantTensorDescriptor(ref_desc, "ref_desc");
    }
#endif

    static_assert(DataPerRead == 1 || (SrcDesc{}.GetStride(Number<nDim - 1>{}) == 1 &&
                                       DstDesc{}.GetStride(Number<nDim - 1>{}) == 1),
                  "wrong! only support stride[nDim-1] == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I4) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I4) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");
    static_assert(
        SrcDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0 &&
            DstDesc{}.GetStride(Number<nDim - 2>{}) % DataPerRead == 0,
        "wrong! src and dst stride[nDim-2] should be multiple of DataPerRead to keep alignment");

    constexpr index_t L5 = SrcOpLengths{}.Get(I5);
    constexpr index_t L_Back = SrcOpLengths{}.Back();

    static_assert(L5 % DataPerRead == 0, "wrong! L5 should be evenly divided by DataPerRead");
    static_assert(L_Back % DataPerRead == 0,
                  "wrong! lengths[nDim-1] should be evenly divided by DataPerRead");

    constexpr index_t nloop_d5 = L5 / DataPerRead;
    constexpr index_t nRead = L_Back / DataPerRead;
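
    // in the nd version, all leading dimensions are traversed at compile time via
    // static_ford, and the innermost dimension is copied in vector_t chunks of
    // DataPerRead elements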

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
                    {
                        for(index_t iloop_d5 = 0; iloop_d5 < nloop_d5; ++iloop_d5)
                        {
                            const index_t src_index = src_desc.Get1dIndex(
                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
    static_ford<decltype(ref_desc.GetLengths().PopBack())>{}([=](auto Ids) {
        static_for<0, nRead, 1>{}([=](auto IRead) {
            constexpr auto multi_id = decltype(Ids){}.PushBack(Number<IRead.Get() * DataPerRead>{});

                            const index_t dst_index = dst_desc.Get1dIndex(
                                did0, did1, did2, did3, did4, iloop_d5 * DataPerRead);
            const index_t src_index = src_desc.Get1dIndex(multi_id);

                            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                *(reinterpret_cast<const vector_t*>(p_src + src_index));
                        }
                    }
                }
            }
        }
    }
}

// need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_8d_tensor_copy(SrcDesc,
                                          const Float* __restrict__ p_src,
                                          DstDesc,
                                          Float* __restrict__ p_dst,
                                          SrcOpLengths,
                                          Number<DataPerRead>)
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    static_assert(SrcDesc{}.GetDimension() == 8 && DstDesc{}.GetDimension() == 8 &&
                      SrcOpLengths::nDim == 8,
                  "wrong! should be 8 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I7) == 1 && DstDesc{}.GetStride(I7) == 1,
                  "wrong! only support stride7 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I6) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I6) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

    constexpr index_t L7 = SrcOpLengths{}.Get(I7);

    static_assert(L7 % DataPerRead == 0, "wrong! L7 should be evenly divided by DataPerRead");

    constexpr index_t nloop_d7 = L7 / DataPerRead;

    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
                    {
                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
                        {
                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
                            {
                                for(index_t iloop_d7 = 0; iloop_d7 < nloop_d7; ++iloop_d7)
                                {
                                    const index_t src_index = src_desc.Get1dIndex(
                                        did0, did1, did2, did3, did4, did5, did6,
                                        iloop_d7 * DataPerRead);

                                    const index_t dst_index = dst_desc.Get1dIndex(
                                        did0, did1, did2, did3, did4, did5, did6,
                                        iloop_d7 * DataPerRead);

                                    *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                        *(reinterpret_cast<const vector_t*>(p_src + src_index));
                                }
                            }
                        }
                    }
                }
            }
        }
    }
}

// need to assume src and dst are aligned
template <class Float, class SrcDesc, class DstDesc, class SrcOpLengths, index_t DataPerRead>
__device__ void threadwise_10d_tensor_copy(SrcDesc,
                                           const Float* __restrict__ p_src,
                                           DstDesc,
                                           Float* __restrict__ p_dst,
                                           SrcOpLengths,
                                           Number<DataPerRead>)
{
    using vector_t = typename vector_type<Float, DataPerRead>::MemoryType;

    static_assert(SrcDesc{}.GetDimension() == 10 && DstDesc{}.GetDimension() == 10 &&
                      SrcOpLengths::GetSize() == 10,
                  "wrong! should be 10 dimension");

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};
    constexpr auto I4 = Number<4>{};
    constexpr auto I5 = Number<5>{};
    constexpr auto I6 = Number<6>{};
    constexpr auto I7 = Number<7>{};
    constexpr auto I8 = Number<8>{};
    constexpr auto I9 = Number<9>{};

    constexpr auto src_desc = SrcDesc{};
    constexpr auto dst_desc = DstDesc{};
    constexpr auto ref_desc = make_ConstantTensorDescriptor(SrcOpLengths{});

    static_assert(SrcDesc{}.GetStride(I9) == 1 && DstDesc{}.GetStride(I9) == 1,
                  "wrong! only support stride9 == 1!\n");

    static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
                  "wrong! only support DataPerRead == 1, 2 or 4!\n");

    static_assert(SrcDesc{}.GetStride(I8) % DataPerRead == 0 &&
                      DstDesc{}.GetStride(I8) % DataPerRead == 0,
                  "wrong! src and dst stride should be multiple of DataPerRead to keep alignment");

    constexpr index_t L9 = SrcOpLengths{}.Get(I9);

    static_assert(L9 % DataPerRead == 0, "wrong! L9 should be evenly divided by DataPerRead");

    constexpr index_t nloop_d9 = L9 / DataPerRead;

#pragma unroll
    for(index_t did0 = 0; did0 < ref_desc.GetLength(I0); ++did0)
    {
#pragma unroll
        for(index_t did1 = 0; did1 < ref_desc.GetLength(I1); ++did1)
        {
#pragma unroll
            for(index_t did2 = 0; did2 < ref_desc.GetLength(I2); ++did2)
            {
#pragma unroll
                for(index_t did3 = 0; did3 < ref_desc.GetLength(I3); ++did3)
                {
#pragma unroll
                    for(index_t did4 = 0; did4 < ref_desc.GetLength(I4); ++did4)
                    {
#pragma unroll
                        for(index_t did5 = 0; did5 < ref_desc.GetLength(I5); ++did5)
                        {
#pragma unroll
                            for(index_t did6 = 0; did6 < ref_desc.GetLength(I6); ++did6)
                            {
#pragma unroll
                                for(index_t did7 = 0; did7 < ref_desc.GetLength(I7); ++did7)
                                {
#pragma unroll
                                    for(index_t did8 = 0; did8 < ref_desc.GetLength(I8); ++did8)
                                    {
#pragma unroll
                                        for(index_t iloop_d9 = 0; iloop_d9 < nloop_d9; ++iloop_d9)
                                        {
                                            const index_t src_index = src_desc.Get1dIndex(
                                                did0, did1, did2, did3, did4, did5, did6, did7,
                                                did8, iloop_d9 * DataPerRead);

                                            const index_t dst_index = dst_desc.Get1dIndex(
                                                did0, did1, did2, did3, did4, did5, did6, did7,
                                                did8, iloop_d9 * DataPerRead);

                                            *(reinterpret_cast<vector_t*>(p_dst + dst_index)) =
                                                *(reinterpret_cast<const vector_t*>(p_src +
                                                                                    src_index));
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }
    }
            const index_t dst_index = dst_desc.Get1dIndex(multi_id);

            *(reinterpret_cast<vector_t*>(&p_dst[dst_index])) =
                *(reinterpret_cast<const vector_t*>(&p_src[src_index]));
        });
    });
}