mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor
This commit is contained in:
@@ -77,8 +77,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmRowThreadPerCluster = 4;
|
||||
constexpr unsigned GemmColumnThreadPerCluster = 8;
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
@@ -120,7 +120,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
|
||||
#if 1
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
|
||||
#else
|
||||
#elif 0
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
#endif
|
||||
<GridSize,
|
||||
@@ -135,8 +135,8 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
GemmRowThreadPerCluster,
|
||||
GemmColumnThreadPerCluster,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
|
||||
@@ -76,8 +76,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
constexpr unsigned KPerThread = 1;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerClusterRow = 1;
|
||||
constexpr unsigned GemmThreadPerClusterColumn = 4;
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 1;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 32;
|
||||
#elif 0
|
||||
@@ -89,8 +89,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerClusterRow = 4;
|
||||
constexpr unsigned GemmThreadPerClusterColumn = 4;
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
@@ -102,8 +102,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmRowThreadPerCluster = 4;
|
||||
constexpr unsigned GemmColumnThreadPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 4;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 2;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 64;
|
||||
@@ -119,8 +119,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
|
||||
constexpr unsigned GemmRowThreadPerCluster = 8;
|
||||
constexpr unsigned GemmColumnThreadPerCluster = 8;
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 8;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
@@ -171,8 +171,8 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
BPerThread,
|
||||
KPerThread,
|
||||
CPerThread,
|
||||
GemmRowThreadPerCluster,
|
||||
GemmColumnThreadPerCluster,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
|
||||
@@ -449,5 +449,37 @@ struct Blockwise2dTensorCopy3
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
if(DataPerRead == 1)
|
||||
{
|
||||
p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] =
|
||||
p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
|
||||
}
|
||||
else if(DataPerRead == 2)
|
||||
{
|
||||
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
|
||||
nloop_d0 * dst_loop_stride)) =
|
||||
*(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset +
|
||||
nloop_d0 * src_loop_stride));
|
||||
}
|
||||
else if(DataPerRead == 4)
|
||||
{
|
||||
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
|
||||
nloop_d0 * dst_loop_stride)) =
|
||||
*(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset +
|
||||
nloop_d0 * src_loop_stride));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmThreadPerClusterRow,
|
||||
unsigned GemmThreadPerClusterColumn,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
@@ -192,8 +192,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmThreadPerClusterRow,
|
||||
unsigned GemmThreadPerClusterColumn,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
@@ -192,8 +192,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmThreadPerClusterRow,
|
||||
unsigned GemmThreadPerClusterColumn,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1>
|
||||
__global__ void
|
||||
@@ -159,8 +159,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmRowThreadPerCluster,
|
||||
unsigned GemmColumnThreadPerCluster,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1>
|
||||
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline(
|
||||
@@ -175,8 +175,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmRowThreadPerCluster,
|
||||
GemmColumnThreadPerCluster,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
Reference in New Issue
Block a user