mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
Deprecate static kernel (#42)
* deprecate static kernels
[ROCm/composable_kernel commit: 81c942cd7e]
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
#ifndef CONV_COMMON_HPP
|
||||
#define CONV_COMMON_HPP
|
||||
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "dynamic_tensor_descriptor.hpp"
|
||||
|
||||
enum ConvTensorLayout
|
||||
@@ -13,53 +12,6 @@ enum ConvTensorLayout
|
||||
NHWCc
|
||||
};
|
||||
|
||||
template <class InDesc,
|
||||
class WeiDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LeftPads,
|
||||
class RightPads>
|
||||
constexpr auto get_convolution_output_default_4d_tensor_descriptor(
|
||||
InDesc, WeiDesc, ConvStrides, ConvDilations, LeftPads, RightPads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto in_desc = InDesc{};
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
static_assert(in_desc.GetNumOfDimension() == 4, "input nDim is not 4");
|
||||
static_assert(wei_desc.GetNumOfDimension() == 4, "weight nDim is not 4");
|
||||
static_assert(in_desc.GetLength(I1) == wei_desc.GetLength(I1),
|
||||
"input & weight dimension not consistent");
|
||||
|
||||
constexpr index_t N = in_desc.GetLength(I0);
|
||||
constexpr index_t Hi = in_desc.GetLength(I2);
|
||||
constexpr index_t Wi = in_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t K = wei_desc.GetLength(I0);
|
||||
constexpr index_t Y = wei_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t LeftPadH = LeftPads{}.Get(I0);
|
||||
constexpr index_t LeftPadW = LeftPads{}.Get(I1);
|
||||
|
||||
constexpr index_t RightPadH = RightPads{}.Get(I0);
|
||||
constexpr index_t RightPadW = RightPads{}.Get(I1);
|
||||
|
||||
constexpr index_t YEff = (Y - 1) * ConvDilations{}[0] + 1;
|
||||
constexpr index_t XEff = (X - 1) * ConvDilations{}[1] + 1;
|
||||
|
||||
constexpr index_t Ho = (Hi + LeftPadH + RightPadH - YEff) / ConvStrides{}[0] + 1;
|
||||
constexpr index_t Wo = (Wi + LeftPadW + RightPadW - XEff) / ConvStrides{}[1] + 1;
|
||||
|
||||
return make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});
|
||||
}
|
||||
|
||||
template <typename... InDesc,
|
||||
typename... WeiDesc,
|
||||
typename ConvStrides,
|
||||
@@ -131,30 +83,4 @@ calculate_convolution_flops(const InDesc& in_desc, const WeiDesc& wei_desc, cons
|
||||
return std::size_t(2) * N * K * Ho * Wo * C * Y * X;
|
||||
}
|
||||
|
||||
template <class Float, class InDesc, class WeiDesc, class OutDesc>
|
||||
constexpr std::size_t calculate_convolution_memory_size(Float, InDesc, WeiDesc, OutDesc)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr auto wei_desc = WeiDesc{};
|
||||
constexpr auto out_desc = OutDesc{};
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr index_t N = out_desc.GetLength(I0);
|
||||
constexpr index_t K = out_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_desc.GetLength(I3);
|
||||
|
||||
constexpr index_t C = wei_desc.GetLength(I1);
|
||||
constexpr index_t Y = wei_desc.GetLength(I2);
|
||||
constexpr index_t X = wei_desc.GetLength(I3);
|
||||
|
||||
return sizeof(Float) *
|
||||
(InDesc::GetElementSpace() + WeiDesc::GetElementSpace() + OutDesc::GetElementSpace());
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -1,221 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v1r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<1, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#elif 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<2, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<2, 4>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 4;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 4;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM = C * Y * X;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv_bwd_data = GridwiseConvolutionBackwardDataImplicitGemm_v1r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>;
|
||||
|
||||
for(index_t i = 0; i < 1; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -1,171 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v1r2_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// BlockSize = 256, each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t BPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 32;
|
||||
constexpr index_t KPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using OutBlockCopySubLengths_K_B_N0 = Sequence<2, 1, 4>;
|
||||
using OutBlockCopyClusterLengths_K_B_N0 = Sequence<8, 32, 1>;
|
||||
|
||||
constexpr index_t OutBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t OutBlockCopyDstDataPerWrite_N0 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_K_E_C0 = Sequence<2, 4, 1>;
|
||||
using WeiBlockCopyClusterLengths_K_E_C0 = Sequence<8, 8, 4>;
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_C0 = 1;
|
||||
|
||||
constexpr index_t InThreadCopyDstDataPerWrite_B = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t C0 = GemmMPerThread;
|
||||
constexpr index_t N0 = GemmNPerThread;
|
||||
|
||||
constexpr index_t C1 = C / C0;
|
||||
constexpr index_t N1 = N / N0;
|
||||
|
||||
constexpr index_t E = C1 * Y * X;
|
||||
constexpr index_t B = (N1 * Ho * Wo);
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((E + EPerBlock - 1) / EPerBlock) * ((B + BPerBlock - 1) / BPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv_bwd_data =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v1r2_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
EPerBlock,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
OutBlockCopySubLengths_K_B_N0,
|
||||
OutBlockCopyClusterLengths_K_B_N0,
|
||||
OutBlockCopySrcDataPerRead_B,
|
||||
OutBlockCopyDstDataPerWrite_N0,
|
||||
WeiBlockCopySubLengths_K_E_C0,
|
||||
WeiBlockCopyClusterLengths_K_E_C0,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_C0,
|
||||
InThreadCopyDstDataPerWrite_B>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
launch_kernel(run_gridwise_operation<gridwise_conv_bwd_data,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -1,267 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
// GemmABlockCopySrcDataPerRead_GemmM = 4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 8>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<16, 16>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<8, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN = 1;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HTildaRight = math::min(
|
||||
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WTildaRight = math::min(
|
||||
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
|
||||
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
using GridwiseConvBwdData =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v4r1_nchw_kcyx_nkhw<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmN,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>;
|
||||
|
||||
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
|
||||
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
|
||||
constexpr index_t gemm_k = gemm_sizes.At(2);
|
||||
constexpr bool is_gemm_not_empty = gemm_k > 0;
|
||||
|
||||
// only compile and run if GEMM is no empty
|
||||
static_if<is_gemm_not_empty>{}([&](auto fwd) {
|
||||
launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
decltype(gemm_id)>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nkhw_device_buf.GetDeviceBuffer()),
|
||||
fwd(gemm_id));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
in_nchw_device_buf.FromDevice(in_nchw.mData.data());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -1,266 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
namespace launcher {
|
||||
|
||||
using namespace ck;
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename InLeftPads,
|
||||
typename InRightPads>
|
||||
void device_convolution_backward_data_implicit_gemm_v5r1_nhwc_kyxc_nhwk(InDesc in_nchw_desc,
|
||||
Tensor<T>& in_nchw,
|
||||
WeiDesc wei_kcyx_desc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc out_nkhw_desc,
|
||||
const Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
std::size_t nrepeat)
|
||||
{
|
||||
constexpr index_t N = out_nkhw_desc.GetLengths()[0];
|
||||
constexpr index_t K = out_nkhw_desc.GetLengths()[1];
|
||||
constexpr index_t C = wei_kcyx_desc.GetLengths()[1];
|
||||
|
||||
constexpr index_t Hi = in_nchw_desc.GetLengths()[2];
|
||||
constexpr index_t Wi = in_nchw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLengths()[2];
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t Y = wei_kcyx_desc.GetLengths()[2];
|
||||
constexpr index_t X = wei_kcyx_desc.GetLengths()[3];
|
||||
|
||||
constexpr index_t ConvStrideH = ConvStrides{}[0];
|
||||
constexpr index_t ConvStrideW = ConvStrides{}[1];
|
||||
|
||||
constexpr index_t ConvDilationH = ConvDilations{}[0];
|
||||
constexpr index_t ConvDilationW = ConvDilations{}[1];
|
||||
|
||||
constexpr auto in_nhwc_desc = make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
|
||||
constexpr auto wei_kyxc_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
|
||||
constexpr auto out_nhwk_desc = make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
|
||||
|
||||
Tensor<float> in_nhwc(make_HostTensorDescriptor(in_nhwc_desc));
|
||||
Tensor<float> wei_kyxc(make_HostTensorDescriptor(wei_kyxc_desc));
|
||||
Tensor<float> out_nhwk(make_HostTensorDescriptor(out_nhwk_desc));
|
||||
|
||||
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
|
||||
};
|
||||
|
||||
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
|
||||
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
|
||||
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
|
||||
|
||||
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
|
||||
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
|
||||
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
|
||||
|
||||
#if 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 8;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 1, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 16;
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadM = 4;
|
||||
constexpr index_t GemmThreadGemmDataPerReadN = 4;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 2, 4>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM = Sequence<1, 1, 8, 32>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmM = 4;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 4;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 8, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN = Sequence<1, 1, 2, 128>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK2 = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmN1 = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
constexpr index_t GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
constexpr index_t YTilda = ConvStrideH / GcdStrideDilationH;
|
||||
constexpr index_t XTilda = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
constexpr index_t YDot = math::integer_divide_ceil(Y, YTilda);
|
||||
constexpr index_t XDot = math::integer_divide_ceil(X, XTilda);
|
||||
|
||||
constexpr index_t HTilda = Ho + math::integer_divide_ceil(ConvDilationH * (Y - 1), ConvStrideH);
|
||||
constexpr index_t WTilda = Wo + math::integer_divide_ceil(ConvDilationW * (X - 1), ConvStrideW);
|
||||
|
||||
constexpr index_t HTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[0] - ConvDilationH * (YTilda - 1)), ConvStrides{}[0]);
|
||||
constexpr index_t WTildaLeft = math::integer_divide_floor(
|
||||
math::max(0, InLeftPads{}[1] - ConvDilationW * (XTilda - 1)), ConvStrides{}[1]);
|
||||
|
||||
constexpr index_t HTildaRight = math::min(
|
||||
HTilda, math::integer_divide_ceil(InLeftPads{}[0] + Hi - 1, ConvStrides{}[0]) + 1);
|
||||
constexpr index_t WTildaRight = math::min(
|
||||
WTilda, math::integer_divide_ceil(InLeftPads{}[1] + Wi - 1, ConvStrides{}[1]) + 1);
|
||||
|
||||
constexpr index_t HTildaSlice = HTildaRight - HTildaLeft;
|
||||
constexpr index_t WTildaSlice = WTildaRight - WTildaLeft;
|
||||
|
||||
constexpr index_t GemmM = C;
|
||||
constexpr index_t GemmN = N * HTildaSlice * WTildaSlice;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t i = 0; i < nrepeat; ++i)
|
||||
{
|
||||
using GridwiseConvBwdData =
|
||||
GridwiseConvolutionBackwardDataImplicitGemm_v5r1_nhwc_kyxc_nhwk<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nhwc_desc),
|
||||
decltype(wei_kyxc_desc),
|
||||
decltype(out_nhwk_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmThreadGemmDataPerReadM,
|
||||
GemmThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmM,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK0_GemmK1_GemmK2_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK0_GemmK1_GemmK2_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmK2,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmN1>;
|
||||
|
||||
static_for<0, GridwiseConvBwdData::GetNumberOfGemm(), 1>{}([&](auto gemm_id) {
|
||||
constexpr auto gemm_sizes = GridwiseConvBwdData::GetGemmSize(gemm_id);
|
||||
constexpr index_t gemm_k2 = gemm_sizes[Number<4>{}];
|
||||
constexpr bool is_gemm_not_empty = gemm_k2 > 0;
|
||||
|
||||
// only compile and run if GEMM is no empty
|
||||
static_if<is_gemm_not_empty>{}([&](auto fwd) {
|
||||
launch_kernel(run_gridwise_operation<GridwiseConvBwdData,
|
||||
T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
const T* const __restrict__,
|
||||
decltype(gemm_id)>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<T*>(in_nhwc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(wei_kyxc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<T*>(out_nhwk_device_buf.GetDeviceBuffer()),
|
||||
fwd(gemm_id));
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
in_nhwc_device_buf.FromDevice(in_nhwc.mData.data());
|
||||
|
||||
auto f_nhwc2nchw = [&](auto n, auto c, auto hi, auto wi) {
|
||||
in_nchw(n, c, hi, wi) = in_nhwc(n, hi, wi, c);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nhwc2nchw, N, C, Hi, Wi)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace launcher
|
||||
@@ -1,849 +0,0 @@
|
||||
#pragma once
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer.hpp"
|
||||
|
||||
template <typename T,
|
||||
typename InDesc,
|
||||
typename WeiDesc,
|
||||
typename OutDesc,
|
||||
typename ConvStrides,
|
||||
typename ConvDilations,
|
||||
typename LeftPads,
|
||||
typename RightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
std::cout << "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw" << std::endl;
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto in_nchw_desc =
|
||||
make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
|
||||
constexpr auto wei_kcyx_desc =
|
||||
make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
|
||||
constexpr auto out_nkhw_desc =
|
||||
make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());
|
||||
|
||||
constexpr index_t N = out_nkhw_desc.GetLength(I0);
|
||||
constexpr index_t K = out_nkhw_desc.GetLength(I1);
|
||||
constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
|
||||
constexpr index_t Wo = out_nkhw_desc.GetLength(I3);
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
|
||||
DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());
|
||||
|
||||
in_nchw_device_buf.ToDevice(in_nchw.mData.data());
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
#if 0
|
||||
// cdata = 64, BlockSize = 256, 64x256x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t BPerBlock = 32;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 16;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 32, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 256, 128x128x4
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 128>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 1
|
||||
// cdata = 64, BlockSize = 256, 128x128x16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#elif 0
|
||||
// cdata = 4, BlockSize = 256, 128x128x16
|
||||
// for 1x1
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<4, 1, 1, 2>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 2>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 64x128x4
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 64x128x8
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 64x128x16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 128x64x4
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 8;
|
||||
constexpr index_t EPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 2>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 2, 8, 2>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<2, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 128x64x8
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 8;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 8, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 2;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 128, 128x64x16
|
||||
constexpr index_t BlockSize = 128;
|
||||
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t BPerBlock = 8;
|
||||
constexpr index_t EPerBlock = 16;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 8;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<16, 1, 8, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 4>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 4;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 64x64x8
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t BPerBlock = 8;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 1, 8, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 2>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 32, 32x64x3
|
||||
constexpr index_t BlockSize = 32;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t BPerBlock = 8;
|
||||
constexpr index_t EPerBlock = 3;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 2;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 2>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 32x128x3
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 3;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 2>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 16, 2>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 2;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 64x64x3
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t BPerBlock = 8;
|
||||
constexpr index_t EPerBlock = 3;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<3, 1, 1, 1>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<1, 2, 8, 4>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 1;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<3, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<1, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 1;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 32x128x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 64, BlockSize = 64, 32x128x8
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t KPerBlock = 32;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 4;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 1;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 4;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<2, 2, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<4, 1, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<4, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<2, 32>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 4;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#elif 0
|
||||
// cdata = 32, BlockSize = 256, 64x128x8
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t KPerBlock = 64;
|
||||
constexpr index_t BPerBlock = 16;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 4;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
constexpr index_t GemmNLevel0Cluster = 4;
|
||||
constexpr index_t GemmMLevel1Cluster = 4;
|
||||
constexpr index_t GemmNLevel1Cluster = 4;
|
||||
|
||||
constexpr index_t GemmDataPerReadA = 2;
|
||||
constexpr index_t GemmDataPerReadB = 4;
|
||||
|
||||
using InBlockCopySubLengths_E_N1_B_N2 = Sequence<1, 1, 1, 4>;
|
||||
using InBlockCopyClusterLengths_E_N1_B_N2 = Sequence<8, 2, 16, 1>;
|
||||
using InBlockCopyThreadClusterArrangeOrder = Sequence<0, 1, 3, 2>; // [E, N1, N2, B]
|
||||
using InBlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [E, B, N1, N2]
|
||||
using InBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [E, N1, B, N2]
|
||||
|
||||
constexpr index_t InBlockCopySrcDataPerRead_B = 1;
|
||||
constexpr index_t InBlockCopyDstDataPerWrite_N2 = 4;
|
||||
|
||||
using WeiBlockCopySubLengths_E_K = Sequence<2, 1>;
|
||||
using WeiBlockCopyClusterLengths_E_K = Sequence<4, 64>;
|
||||
using WeiBlockCopyThreadClusterArrangeOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopySrcAccessOrder = Sequence<1, 0>; // [K, E]
|
||||
using WeiBlockCopyDstAccessOrder = Sequence<0, 1>; // [E, K]
|
||||
|
||||
constexpr index_t WeiBlockCopySrcDataPerRead_E = 2;
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThread;
|
||||
|
||||
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv =
|
||||
GridwiseConvolutionForwardImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
T,
|
||||
T,
|
||||
decltype(in_nchw_desc),
|
||||
decltype(wei_kcyx_desc),
|
||||
decltype(out_nkhw_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LeftPads,
|
||||
RightPads,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
GemmNRepeat,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
GemmDataPerReadA,
|
||||
GemmDataPerReadB,
|
||||
InBlockCopySubLengths_E_N1_B_N2,
|
||||
InBlockCopyClusterLengths_E_N1_B_N2,
|
||||
InBlockCopyThreadClusterArrangeOrder,
|
||||
InBlockCopySrcAccessOrder,
|
||||
InBlockCopyDstAccessOrder,
|
||||
InBlockCopySrcDataPerRead_B,
|
||||
InBlockCopyDstDataPerWrite_N2,
|
||||
WeiBlockCopySubLengths_E_K,
|
||||
WeiBlockCopyClusterLengths_E_K,
|
||||
WeiBlockCopyThreadClusterArrangeOrder,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
launch_kernel(run_gridwise_operation<gridwise_conv,
|
||||
const TDevice* const __restrict__,
|
||||
const TDevice* const __restrict__,
|
||||
TDevice* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,207 +0,0 @@
|
||||
#include <unistd.h>
|
||||
#include "device.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "gridwise_operation_wrapper.hpp"
|
||||
#include "gridwise_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp"
|
||||
|
||||
template <class T,
|
||||
class InDesc,
|
||||
class WeiDesc,
|
||||
class OutDesc,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class InLeftPads,
|
||||
class InRightPads>
|
||||
void device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(InDesc,
|
||||
const Tensor<T>& in_nchw,
|
||||
WeiDesc,
|
||||
const Tensor<T>& wei_kcyx,
|
||||
OutDesc,
|
||||
Tensor<T>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
ck::index_t nrepeat)
|
||||
{
|
||||
std::cout << "device_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk" << std::endl;
|
||||
|
||||
using namespace ck;
|
||||
|
||||
using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
|
||||
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
constexpr auto I2 = Number<2>{};
|
||||
constexpr auto I3 = Number<3>{};
|
||||
|
||||
constexpr auto N = OutDesc::GetLengths()[I0];
|
||||
constexpr auto K = OutDesc::GetLengths()[I1];
|
||||
constexpr auto C = WeiDesc::GetLengths()[I1];
|
||||
|
||||
constexpr auto Hi = InDesc::GetLengths()[I2];
|
||||
constexpr auto Wi = InDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Ho = OutDesc::GetLengths()[I2];
|
||||
constexpr auto Wo = OutDesc::GetLengths()[I3];
|
||||
|
||||
constexpr auto Y = WeiDesc::GetLengths()[I2];
|
||||
constexpr auto X = WeiDesc::GetLengths()[I3];
|
||||
|
||||
// compile-time variables
|
||||
constexpr auto in_n_hi_wi_c_desc =
|
||||
make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{});
|
||||
constexpr auto wei_k_y_x_c_desc = make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{});
|
||||
constexpr auto out_n_ho_wo_k_desc =
|
||||
make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{});
|
||||
|
||||
Tensor<float> in_nhwc(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Hi, Wi, C>{})));
|
||||
Tensor<float> wei_kyxc(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<K, Y, X, C>{})));
|
||||
Tensor<float> out_nhwk(
|
||||
make_HostTensorDescriptor(make_native_tensor_descriptor_packed(Sequence<N, Ho, Wo, K>{})));
|
||||
|
||||
auto f_nchw2nhwc = [&](auto n, auto hi, auto wi, auto c) {
|
||||
in_nhwc(n, hi, wi, c) = in_nchw(n, c, hi, wi);
|
||||
};
|
||||
|
||||
auto f_kcyx2kyxc = [&](auto k, auto y, auto x, auto c) {
|
||||
wei_kyxc(k, y, x, c) = wei_kcyx(k, c, y, x);
|
||||
};
|
||||
|
||||
auto f_nkhw2nhwk = [&](auto n, auto ho, auto wo, auto k) {
|
||||
out_nhwk(n, ho, wo, k) = out_nkhw(n, k, ho, wo);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nchw2nhwc, N, Hi, Wi, C)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_kcyx2kyxc, K, Y, X, C)(std::thread::hardware_concurrency());
|
||||
make_ParallelTensorFunctor(f_nkhw2nhwk, N, Ho, Wo, K)(std::thread::hardware_concurrency());
|
||||
|
||||
std::size_t data_sz = sizeof(T);
|
||||
|
||||
DeviceMem in_nhwc_device_buf(data_sz * in_nhwc.mDesc.GetElementSpace());
|
||||
DeviceMem wei_kyxc_device_buf(data_sz * wei_kyxc.mDesc.GetElementSpace());
|
||||
DeviceMem out_nhwk_device_buf(data_sz * out_nhwk.mDesc.GetElementSpace());
|
||||
|
||||
in_nhwc_device_buf.ToDevice(in_nhwc.mData.data());
|
||||
wei_kyxc_device_buf.ToDevice(wei_kyxc.mData.data());
|
||||
out_nhwk_device_buf.ToDevice(out_nhwk.mData.data());
|
||||
|
||||
#if 1
|
||||
// cdata = 16, BlockSize = 64, 16x64x4
|
||||
constexpr index_t BlockSize = 64;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 16;
|
||||
constexpr index_t GemmNPerBlock = 64;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerThread = 2;
|
||||
constexpr index_t GemmNPerThread = 2;
|
||||
constexpr index_t GemmKPerThread = 1;
|
||||
|
||||
constexpr index_t GemmMLevel0Cluster = 2;
|
||||
constexpr index_t GemmNLevel0Cluster = 2;
|
||||
constexpr index_t GemmMLevel1Cluster = 2;
|
||||
constexpr index_t GemmNLevel1Cluster = 8;
|
||||
|
||||
constexpr index_t ThreadGemmDataPerReadM = 2;
|
||||
constexpr index_t ThreadGemmDataPerReadN = 2;
|
||||
|
||||
using GemmABlockCopyThreadSliceLengths_GemmK_GemmM = Sequence<1, 1>;
|
||||
using GemmABlockCopyThreadClusterLengths_GemmK_GemmM = Sequence<4, 16>;
|
||||
|
||||
constexpr index_t GemmABlockCopySrcDataPerRead_GemmK = 1;
|
||||
constexpr index_t GemmABlockCopyDstDataPerWrite_GemmM = 1;
|
||||
|
||||
using GemmBBlockCopyThreadSliceLengths_GemmK_GemmN = Sequence<4, 1>;
|
||||
using GemmBBlockCopyThreadClusterLengths_GemmK_GemmN = Sequence<1, 64>;
|
||||
|
||||
constexpr index_t GemmBBlockCopySrcDataPerRead_GemmK = 4;
|
||||
constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmN = 1;
|
||||
|
||||
constexpr index_t GemmCThreadCopyDstDataPerWrite_GemmM1 = 2;
|
||||
#endif
|
||||
|
||||
constexpr index_t GemmM = K;
|
||||
constexpr index_t GemmN = N * Ho * Wo;
|
||||
|
||||
constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
|
||||
math::integer_divide_ceil(GemmN, GemmNPerBlock);
|
||||
|
||||
printf("%s: BlockSize %u, GridSize %u \n", __func__, BlockSize, GridSize);
|
||||
|
||||
using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_nhwc_kyxc_nhwk<
|
||||
GridSize,
|
||||
BlockSize,
|
||||
TDevice,
|
||||
TDevice,
|
||||
decltype(in_n_hi_wi_c_desc),
|
||||
decltype(wei_k_y_x_c_desc),
|
||||
decltype(out_n_ho_wo_k_desc),
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
InLeftPads,
|
||||
InRightPads,
|
||||
GemmMPerBlock,
|
||||
GemmNPerBlock,
|
||||
GemmKPerBlock,
|
||||
GemmMPerThread,
|
||||
GemmNPerThread,
|
||||
GemmKPerThread,
|
||||
GemmMLevel0Cluster,
|
||||
GemmNLevel0Cluster,
|
||||
GemmMLevel1Cluster,
|
||||
GemmNLevel1Cluster,
|
||||
ThreadGemmDataPerReadM,
|
||||
ThreadGemmDataPerReadN,
|
||||
GemmABlockCopyThreadSliceLengths_GemmK_GemmM,
|
||||
GemmABlockCopyThreadClusterLengths_GemmK_GemmM,
|
||||
GemmABlockCopySrcDataPerRead_GemmK,
|
||||
GemmABlockCopyDstDataPerWrite_GemmM,
|
||||
GemmBBlockCopyThreadSliceLengths_GemmK_GemmN,
|
||||
GemmBBlockCopyThreadClusterLengths_GemmK_GemmN,
|
||||
GemmBBlockCopySrcDataPerRead_GemmK,
|
||||
GemmBBlockCopyDstDataPerWrite_GemmN,
|
||||
GemmCThreadCopyDstDataPerWrite_GemmM1>;
|
||||
|
||||
for(index_t i = 0; i < 5; ++i)
|
||||
{
|
||||
std::cout << "Start running " << nrepeat << " times..." << std::endl;
|
||||
|
||||
KernelTimer timer;
|
||||
timer.Start();
|
||||
|
||||
for(index_t j = 0; j < nrepeat; ++j)
|
||||
{
|
||||
launch_kernel(run_gridwise_operation<gridwise_conv,
|
||||
const TDevice* const __restrict__,
|
||||
const TDevice* const __restrict__,
|
||||
TDevice* const __restrict__>,
|
||||
dim3(GridSize),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
0,
|
||||
static_cast<TDevice*>(in_nhwc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(wei_kyxc_device_buf.GetDeviceBuffer()),
|
||||
static_cast<TDevice*>(out_nhwk_device_buf.GetDeviceBuffer()));
|
||||
}
|
||||
|
||||
timer.End();
|
||||
|
||||
float ave_time = timer.GetElapsedTime() / nrepeat;
|
||||
|
||||
float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
|
||||
(std::size_t(1000) * 1000 * 1000) / ave_time;
|
||||
|
||||
std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
|
||||
}
|
||||
|
||||
out_nhwk_device_buf.FromDevice(out_nhwk.mData.data());
|
||||
|
||||
auto f_nhwk2nkhw = [&](auto n, auto k, auto ho, auto wo) {
|
||||
out_nkhw(n, k, ho, wo) = out_nhwk(n, ho, wo, k);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f_nhwk2nkhw, N, K, Ho, Wo)(std::thread::hardware_concurrency());
|
||||
}
|
||||
@@ -257,7 +257,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
||||
GemmKPerBlock,
|
||||
GemmMPerWave,
|
||||
GemmNPerWave,
|
||||
GemmK1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
|
||||
|
||||
@@ -57,7 +57,35 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nh
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8]
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 64;
|
||||
constexpr index_t GemmNPerWave = 64;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 1;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
|
||||
@@ -56,7 +56,63 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -111,34 +167,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nh
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#elif 0
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 4;
|
||||
#endif
|
||||
|
||||
|
||||
@@ -56,7 +56,63 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
||||
const auto out_n_ho_wo_k_desc =
|
||||
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths);
|
||||
|
||||
#if 0
|
||||
#if 1
|
||||
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 256;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 4;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [256, 256, 4, 8] for fp16
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
@@ -139,34 +195,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
||||
constexpr index_t BlockSize = 256;
|
||||
|
||||
constexpr index_t GemmMPerBlock = 128;
|
||||
constexpr index_t GemmNPerBlock = 128;
|
||||
constexpr index_t GemmKPerBlock = 4;
|
||||
|
||||
constexpr index_t GemmMPerWave = 32;
|
||||
constexpr index_t GemmNPerWave = 32;
|
||||
constexpr index_t GemmK1 = 4;
|
||||
|
||||
constexpr index_t MRepeat = 2;
|
||||
constexpr index_t NRepeat = 2;
|
||||
|
||||
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>;
|
||||
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
||||
|
||||
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4;
|
||||
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
||||
|
||||
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
||||
#elif 1
|
||||
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
||||
|
||||
@@ -215,7 +215,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r5r2_nchw_kcyx_nkhw(
|
||||
Sequence<4, 3, 2, 0, 1>, // BBlockTransferSrcAccessOrder
|
||||
GemmBBlockTransferSrcVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
GemmBBlockTransferDstVectorTensorLengths_GK0_GN0_GN10_GN11_GK1,
|
||||
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
Sequence<0, 1, 2, 3, 4>, // BBlockTransferSrcVectorTensorContiguousDimOrder
|
||||
Sequence<3, 4, 5, 0, 1, 2>, // CThreadTransferSrcDstAccessOrder
|
||||
5, // CThreadTransferSrcDstVectorDim
|
||||
GemmCThreadTransferDstScalarPerVector_BN1,
|
||||
|
||||
@@ -1,23 +1,6 @@
|
||||
#pragma once
|
||||
#include "host_tensor.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
|
||||
template <typename TensorDesc, std::size_t... Is>
|
||||
auto make_HostTensorDescriptor_impl(TensorDesc, std::integer_sequence<std::size_t, Is...>)
|
||||
{
|
||||
std::initializer_list<std::size_t> lengths = {TensorDesc::GetLengths()[Is]...};
|
||||
std::initializer_list<std::size_t> strides = {TensorDesc::GetStrides()[Is]...};
|
||||
|
||||
return HostTensorDescriptor(lengths, strides);
|
||||
}
|
||||
|
||||
template <typename TensorDesc>
|
||||
auto make_HostTensorDescriptor(TensorDesc)
|
||||
{
|
||||
return make_HostTensorDescriptor_impl(
|
||||
TensorDesc{}, std::make_integer_sequence<std::size_t, TensorDesc::GetNumOfDimension()>{});
|
||||
}
|
||||
|
||||
template <typename TensorDesc>
|
||||
void ostream_tensor_descriptor(TensorDesc, std::ostream& os = std::cout)
|
||||
|
||||
Reference in New Issue
Block a user