mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 09:16:52 +00:00
refactor
This commit is contained in:
@@ -449,5 +449,37 @@ struct Blockwise2dTensorCopy3
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
if(get_thread_local_1d_id() < tail_d0 * thread_per_d1)
|
||||
{
|
||||
if(DataPerRead == 1)
|
||||
{
|
||||
p_dst[mDstMyThreadOffset + nloop_d0 * dst_loop_stride] =
|
||||
p_src[mSrcMyThreadOffset + nloop_d0 * src_loop_stride];
|
||||
}
|
||||
else if(DataPerRead == 2)
|
||||
{
|
||||
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset +
|
||||
nloop_d0 * dst_loop_stride)) =
|
||||
*(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset +
|
||||
nloop_d0 * src_loop_stride));
|
||||
}
|
||||
else if(DataPerRead == 4)
|
||||
{
|
||||
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset +
|
||||
nloop_d0 * dst_loop_stride)) =
|
||||
*(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset +
|
||||
nloop_d0 * src_loop_stride));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmThreadPerClusterRow,
|
||||
unsigned GemmThreadPerClusterColumn,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
@@ -192,8 +192,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmThreadPerClusterRow,
|
||||
unsigned GemmThreadPerClusterColumn,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
@@ -192,8 +192,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmThreadPerClusterRow,
|
||||
unsigned GemmThreadPerClusterColumn,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1>
|
||||
__global__ void
|
||||
@@ -159,8 +159,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
@@ -20,8 +20,8 @@ template <unsigned GridSize,
|
||||
unsigned BPerThread,
|
||||
unsigned KPerThread,
|
||||
unsigned CPerThread,
|
||||
unsigned GemmRowThreadPerCluster,
|
||||
unsigned GemmColumnThreadPerCluster,
|
||||
unsigned GemmThreadPerColumnPerCluster,
|
||||
unsigned GemmThreadPerRowPerCluster,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1>
|
||||
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline(
|
||||
@@ -175,8 +175,8 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmRowThreadPerCluster,
|
||||
GemmColumnThreadPerCluster,
|
||||
GemmThreadPerColumnPerCluster,
|
||||
GemmThreadPerRowPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
|
||||
Reference in New Issue
Block a user