mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
cleaning up
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER
|
||||
#ifndef CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
#define CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
|
||||
#include "common_header.hpp"
|
||||
#include "ConstantTensorDescriptor.hpp"
|
||||
@@ -23,8 +23,7 @@ template <index_t GridSize,
|
||||
index_t BPerBlock,
|
||||
index_t KPerBlock,
|
||||
index_t EPerBlock,
|
||||
index_t N1,
|
||||
index_t N2,
|
||||
index_t GemmNRepeat,
|
||||
index_t GemmMPerThreadSubC,
|
||||
index_t GemmNPerThreadSubC,
|
||||
index_t GemmMLevel0Cluster,
|
||||
@@ -47,17 +46,19 @@ template <index_t GridSize,
|
||||
class WeiBlockCopySrcAccessOrder,
|
||||
class WeiBlockCopyDstAccessOrder,
|
||||
index_t WeiBlockCopySrcDataPerRead_E,
|
||||
index_t WeiBlockCopyDstDataPerWrite_K,
|
||||
index_t OutThreadCopyDataPerAccess_W>
|
||||
index_t WeiBlockCopyDstDataPerWrite_K>
|
||||
struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
{
|
||||
__device__ void Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
__device__ void __launch_bounds__(BlockSize, 2)
|
||||
Run(const Float* const __restrict__ p_in_global,
|
||||
const Float* const __restrict__ p_wei_global,
|
||||
Float* const __restrict__ p_out_global) const
|
||||
{
|
||||
// this is a mess
|
||||
// TODO: find more elegent way of specifying (or calculating) performance parameters
|
||||
static_assert(N2 == GemmNPerThreadSubC, "wrong!");
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
|
||||
static_assert((N1 * N2 * BPerBlock) %
|
||||
(GemmNPerThreadSubC * GemmNLevel0Cluster * GemmNLevel1Cluster) ==
|
||||
0,
|
||||
@@ -464,4 +465,4 @@ struct GridwiseConvolutionImplicitGemm_v4r1_nchw_kcyx_nkhw_lds_double_buffer
|
||||
};
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
#endif // CK_GRIDWISE_CONVOLUTION_IMPLICIT_GEMM_V4R1_NCHW_KCYX_NKHW_LDS_DOUBLE_BUFFER_HPP
|
||||
|
||||
@@ -54,11 +54,6 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
|
||||
out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());
|
||||
|
||||
constexpr index_t N1 = 2;
|
||||
constexpr index_t N2 = 4;
|
||||
|
||||
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
|
||||
|
||||
#if 1
|
||||
// each thread hold 64 data
|
||||
constexpr index_t BlockSize = 256;
|
||||
@@ -67,6 +62,8 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t KPerBlock = 128;
|
||||
constexpr index_t EPerBlock = 8;
|
||||
|
||||
constexpr index_t GemmNRepeat = 2;
|
||||
|
||||
constexpr index_t GemmMPerThreadSubC = 4;
|
||||
constexpr index_t GemmNPerThreadSubC = 4;
|
||||
constexpr index_t GemmMLevel0Cluster = 4;
|
||||
@@ -168,6 +165,11 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
constexpr index_t WeiBlockCopyDstDataPerWrite_K = 1;
|
||||
#endif
|
||||
|
||||
constexpr index_t N1 = GemmNRepeat;
|
||||
constexpr index_t N2 = GemmNPerThreadSubC;
|
||||
|
||||
constexpr index_t B = (N * Ho * Wo) / (N1 * N2);
|
||||
|
||||
constexpr index_t GridSize =
|
||||
((B + BPerBlock - 1) / BPerBlock) * ((K + KPerBlock - 1) / KPerBlock);
|
||||
|
||||
@@ -192,8 +194,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
BPerBlock,
|
||||
KPerBlock,
|
||||
EPerBlock,
|
||||
N1,
|
||||
N2,
|
||||
GemmNRepeat,
|
||||
GemmMPerThreadSubC,
|
||||
GemmNPerThreadSubC,
|
||||
GemmMLevel0Cluster,
|
||||
@@ -216,8 +217,7 @@ void device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(InDesc,
|
||||
WeiBlockCopySrcAccessOrder,
|
||||
WeiBlockCopyDstAccessOrder,
|
||||
WeiBlockCopySrcDataPerRead_E,
|
||||
WeiBlockCopyDstDataPerWrite_K,
|
||||
OutThreadCopyDataPerAccess_W>{};
|
||||
WeiBlockCopyDstDataPerWrite_K>{};
|
||||
|
||||
float time = launch_kernel(run_gridwise_convolution_kernel<decltype(gridwise_conv), T>,
|
||||
dim3(GridSize),
|
||||
|
||||
@@ -379,7 +379,7 @@ int main(int argc, char* argv[])
|
||||
#elif 0
|
||||
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw(
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
#elif 0
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v4r1_nchw_kcyx_nkhw(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
|
||||
Reference in New Issue
Block a user