mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
add another version of blockwise 2d copy, refactor
This commit is contained in:
@@ -86,6 +86,9 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 4;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 64;
|
||||
#endif
|
||||
|
||||
@@ -137,7 +140,9 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1,
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>
|
||||
WeiBlockCopyThreadPerDim1,
|
||||
InBlockCopyDataPerRead,
|
||||
WeiBlockCopyDataPerRead>
|
||||
<<<grid_dim, block_dim>>>(in_cnhw_desc,
|
||||
static_cast<T*>(in_cnhw_device_buf.GetDeviceBuffer()),
|
||||
wei_csrk_desc,
|
||||
|
||||
@@ -162,9 +162,9 @@ blockwise_2d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
}
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
struct blockwise_2d_tensor_copy_1
|
||||
struct Blockwise2dTensorCopy1
|
||||
{
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
__device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto dst_from_src_reorder = Sequence<0, 1>{};
|
||||
|
||||
@@ -173,6 +173,8 @@ struct blockwise_2d_tensor_copy_1
|
||||
}
|
||||
};
|
||||
|
||||
// need to be aligned to float4 and float2
|
||||
// stride1 need to be 1 for both source and destination
|
||||
template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
@@ -180,21 +182,27 @@ template <unsigned BlockSize,
|
||||
class SrcOpLengths,
|
||||
unsigned ThreadPerDim0,
|
||||
unsigned ThreadPerDim1>
|
||||
struct blockwise_2d_tensor_copy_2
|
||||
struct Blockwise2dTensorCopy2
|
||||
{
|
||||
unsigned mThreadId0;
|
||||
unsigned mThreadId1;
|
||||
|
||||
__device__ blockwise_2d_tensor_copy_2()
|
||||
__device__ Blockwise2dTensorCopy2()
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! type is not float!\n");
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
|
||||
"wrong! stride is not 1!\n");
|
||||
|
||||
mThreadId0 = get_thread_local_1d_id() / ThreadPerDim1;
|
||||
mThreadId1 = get_thread_local_1d_id() - mThreadId0 * ThreadPerDim1;
|
||||
}
|
||||
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
__device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
|
||||
return;
|
||||
|
||||
@@ -227,22 +235,12 @@ struct blockwise_2d_tensor_copy_2
|
||||
for(unsigned d1v4loop = 0; d1v4loop < Dim1V4Loop; ++d1v4loop)
|
||||
{
|
||||
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
#if 1
|
||||
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float4*>(p_src + sindex));
|
||||
|
||||
#else
|
||||
for(unsigned i = 0; i < 4; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// v2
|
||||
@@ -251,22 +249,11 @@ struct blockwise_2d_tensor_copy_2
|
||||
unsigned did1 =
|
||||
Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 + 2 * mThreadId1;
|
||||
|
||||
#if 1
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float2*>(p_src + sindex));
|
||||
|
||||
#else
|
||||
for(unsigned i = 0; i < 2; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// v1
|
||||
@@ -310,22 +297,11 @@ struct blockwise_2d_tensor_copy_2
|
||||
{
|
||||
unsigned did1 = d1v4loop * 4 * ThreadPerDim1 + 4 * mThreadId1;
|
||||
|
||||
#if 1
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float4*>(p_src + sindex));
|
||||
|
||||
#else
|
||||
for(unsigned i = 0; i < 4; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// v2
|
||||
@@ -334,22 +310,11 @@ struct blockwise_2d_tensor_copy_2
|
||||
unsigned did1 = Dim1V4Loop * 4 * ThreadPerDim1 + d1v2loop * 2 * ThreadPerDim1 +
|
||||
2 * mThreadId1;
|
||||
|
||||
#if 1
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float2*>(p_src + sindex));
|
||||
|
||||
#else
|
||||
for(unsigned i = 0; i < 2; ++i)
|
||||
{
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1 + i);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1 + i);
|
||||
|
||||
p_dst[dindex] = p_src[sindex];
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// v1
|
||||
@@ -385,49 +350,104 @@ struct blockwise_2d_tensor_copy_2
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
struct blockwise_2d_tensor_copy_dummy_1
|
||||
// starting point need to be aligned to float4 or float2 or float
|
||||
// stride1 need to be 1 for both source and destination
|
||||
template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise2dTensorCopy3
|
||||
{
|
||||
unsigned mBegin;
|
||||
unsigned mSrcMyThreadOffset;
|
||||
unsigned mDstMyThreadOffset;
|
||||
|
||||
__device__ blockwise_2d_tensor_copy_dummy_1()
|
||||
__device__ Blockwise2dTensorCopy3()
|
||||
{
|
||||
constexpr unsigned n_total =
|
||||
make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr unsigned n_per_thread = n_total / BlockSize;
|
||||
static_assert(SrcDesc{}.GetStride(I1) == 1 && DstDesc{}.GetStride(I1) == 1,
|
||||
"wrong! only support stride1 == 1!\n");
|
||||
|
||||
mBegin = n_per_thread * get_thread_local_1d_id();
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
|
||||
|
||||
static_assert(L1 % DataPerRead == 0, "wrong! only support mod(L1, DataPerRead) == 0\n");
|
||||
|
||||
constexpr unsigned thread_per_d1 = L1 / DataPerRead;
|
||||
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
static_assert(thread_per_d1 <= BlockSize,
|
||||
"wrong! not enough threads to cover L1 dimension\n");
|
||||
|
||||
const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
|
||||
const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
|
||||
|
||||
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
|
||||
mDstMyThreadOffset = DstDesc{}.Get1dIndex(thread_id_d0, thread_id_d1 * DataPerRead);
|
||||
}
|
||||
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
__device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr unsigned n_total =
|
||||
make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
constexpr unsigned n_per_thread = n_total / BlockSize;
|
||||
using Float2 = float2;
|
||||
using Float4 = float4;
|
||||
|
||||
for(unsigned i = 0; i < n_per_thread; ++i)
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
|
||||
|
||||
constexpr unsigned thread_per_d1 = L1 / DataPerRead;
|
||||
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
p_dst[mBegin + i] = p_src[mBegin + i];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
struct blockwise_2d_tensor_copy_dummy_2
|
||||
{
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr unsigned n_total =
|
||||
make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
|
||||
|
||||
constexpr unsigned n_per_thread = n_total / BlockSize;
|
||||
|
||||
for(unsigned i = 0; i < n_per_thread; ++i)
|
||||
{
|
||||
unsigned index = get_thread_local_1d_id() + BlockSize * i;
|
||||
p_dst[index] = p_src[index];
|
||||
if(get_thread_local_1d_id() > num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
for(unsigned iloop = 0; iloop < nloop_d0; ++iloop)
|
||||
{
|
||||
if(DataPerRead == 1)
|
||||
{
|
||||
p_dst[mDstMyThreadOffset + iloop * dst_loop_stride] =
|
||||
p_src[mSrcMyThreadOffset + iloop * src_loop_stride];
|
||||
}
|
||||
else if(DataPerRead == 2)
|
||||
{
|
||||
*(reinterpret_cast<Float2*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
|
||||
*(reinterpret_cast<Float2*>(p_src + mSrcMyThreadOffset +
|
||||
iloop * src_loop_stride));
|
||||
}
|
||||
else if(DataPerRead == 4)
|
||||
{
|
||||
*(reinterpret_cast<Float4*>(p_dst + mDstMyThreadOffset + iloop * dst_loop_stride)) =
|
||||
*(reinterpret_cast<Float4*>(p_src + mSrcMyThreadOffset +
|
||||
iloop * src_loop_stride));
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -200,9 +200,9 @@ blockwise_4d_tensor_copy_reorder_by_get_dst_from_src(SrcDesc,
|
||||
}
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
struct blockwise_4d_tensor_copy_1
|
||||
struct Blockwise4dTensorCopy1
|
||||
{
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
__device__ void Run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr auto dst_from_src_reorder = Sequence<0, 1, 2, 3>{};
|
||||
|
||||
@@ -217,9 +217,9 @@ template <unsigned BlockSize,
|
||||
class DstDesc,
|
||||
class DstOpLengths,
|
||||
class GlobalLowerPads>
|
||||
struct blockwise_chwn_tensor_copy_with_padding
|
||||
struct BlockwiseChwnTensorCopyPadded
|
||||
{
|
||||
__device__ void run(Float* const __restrict__ p_src,
|
||||
__device__ void Run(Float* const __restrict__ p_src,
|
||||
unsigned c_block_data_begin,
|
||||
unsigned ho_block_data_begin,
|
||||
unsigned wo_block_data_begin,
|
||||
@@ -336,33 +336,4 @@ struct blockwise_chwn_tensor_copy_with_padding
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <unsigned BlockSize, class Float, class SrcDesc, class DstDesc, class SrcOpLengths>
|
||||
struct blockwise_4d_tensor_copy_dummy
|
||||
{
|
||||
unsigned mBegin;
|
||||
|
||||
__device__ blockwise_4d_tensor_copy_dummy()
|
||||
{
|
||||
constexpr unsigned n_total =
|
||||
make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
|
||||
|
||||
constexpr unsigned n_per_thread = n_total / BlockSize;
|
||||
|
||||
mBegin = n_per_thread * get_thread_local_1d_id();
|
||||
}
|
||||
|
||||
__device__ void run(Float* const __restrict__ p_src, Float* __restrict__ p_dst) const
|
||||
{
|
||||
constexpr unsigned n_total =
|
||||
make_ConstantTensorDescriptor(SrcOpLengths{}).GetElementSpace();
|
||||
|
||||
constexpr unsigned n_per_thread = n_total / BlockSize;
|
||||
|
||||
for(unsigned i = 0; i < n_per_thread; ++i)
|
||||
{
|
||||
p_dst[mBegin + i] = p_src[mBegin + i];
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
@@ -15,7 +15,7 @@ template <unsigned BlockSize,
|
||||
unsigned BatchPerThread,
|
||||
unsigned KPerThreadLoop,
|
||||
bool DistributeThreadAlongColumnFirst>
|
||||
struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
|
||||
struct Blockwise1dStridedBatchedGemmBlockABlockBThreadC
|
||||
{
|
||||
unsigned mMyThreadOffsetA = 0;
|
||||
unsigned mMyThreadOffsetB = 0;
|
||||
@@ -27,7 +27,7 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
|
||||
unsigned col_begin;
|
||||
};
|
||||
|
||||
__device__ blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c()
|
||||
__device__ Blockwise1dStridedBatchedGemmBlockABlockBThreadC()
|
||||
{
|
||||
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
|
||||
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
|
||||
@@ -117,7 +117,7 @@ struct blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC, class Accumulator>
|
||||
__device__ void run(FloatA* const p_a_block,
|
||||
__device__ void Run(FloatA* const p_a_block,
|
||||
FloatB* const p_b_block,
|
||||
FloatC* p_c_thread,
|
||||
Accumulator f_accum) const
|
||||
@@ -230,7 +230,7 @@ template <unsigned BlockSize,
|
||||
unsigned MThreadPerCluster,
|
||||
unsigned NThreadPerCluster,
|
||||
bool DistributeThreadAlongColumnFirst>
|
||||
struct blockwise_gemm_block_a_block_b_thread_c
|
||||
struct BlockwiseGemmBlockABlockBThreadC
|
||||
{
|
||||
unsigned mMyThreadOffsetA = 0;
|
||||
unsigned mMyThreadOffsetB = 0;
|
||||
@@ -241,7 +241,7 @@ struct blockwise_gemm_block_a_block_b_thread_c
|
||||
unsigned col_begin;
|
||||
};
|
||||
|
||||
__device__ blockwise_gemm_block_a_block_b_thread_c()
|
||||
__device__ BlockwiseGemmBlockABlockBThreadC()
|
||||
{
|
||||
const auto a_block_mtx = BlockMatrixA{}; // constexpr doesn't compile
|
||||
const auto b_block_mtx = BlockMatrixB{}; // constexpr doesn't compile
|
||||
@@ -360,7 +360,7 @@ struct blockwise_gemm_block_a_block_b_thread_c
|
||||
}
|
||||
|
||||
template <class FloatA, class FloatB, class FloatC, class Accumulator>
|
||||
__device__ void run(FloatA* const p_a_block,
|
||||
__device__ void Run(FloatA* const p_a_block,
|
||||
FloatB* const p_b_block,
|
||||
FloatC* p_c_thread,
|
||||
Accumulator f_accum) const
|
||||
|
||||
@@ -122,25 +122,25 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
|
||||
#endif
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(in_block_global_desc),
|
||||
decltype(in_block_desc),
|
||||
decltype(in_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_block_global_desc),
|
||||
decltype(in_block_desc),
|
||||
decltype(in_block_desc.GetLengths())>{};
|
||||
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_block_global_desc),
|
||||
decltype(wei_block_desc),
|
||||
decltype(wei_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_block_global_desc),
|
||||
decltype(wei_block_desc),
|
||||
decltype(wei_block_desc.GetLengths())>{};
|
||||
|
||||
constexpr auto blockwise_out_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(out_block_desc),
|
||||
decltype(out_block_global_desc),
|
||||
decltype(out_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(out_block_desc),
|
||||
decltype(out_block_global_desc),
|
||||
decltype(out_block_desc.GetLengths())>{};
|
||||
|
||||
// set output tensor in LDS to 0
|
||||
blockwise_4d_tensor_set_zero<BlockSize>(out_block_desc, p_out_block);
|
||||
@@ -149,14 +149,14 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
|
||||
c_block_work_begin += CPerBlock)
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
|
||||
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_work_begin,
|
||||
c_block_work_begin,
|
||||
hi_block_work_begin,
|
||||
wi_block_work_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.run(
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global + wei_global_desc.Get1dIndex(k_block_work_begin, c_block_work_begin, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
@@ -179,7 +179,7 @@ __global__ void gridwise_direct_convolution_1(InGlobalDesc,
|
||||
}
|
||||
|
||||
// copy output tensor from LDS to device mem
|
||||
blockwise_out_copy.run(p_out_block,
|
||||
blockwise_out_copy.Run(p_out_block,
|
||||
p_out_global + out_global_desc.Get1dIndex(n_block_work_begin,
|
||||
k_block_work_begin,
|
||||
ho_block_work_begin,
|
||||
|
||||
@@ -145,18 +145,18 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
#endif
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(in_global_desc),
|
||||
decltype(in_block_desc),
|
||||
decltype(in_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_global_desc),
|
||||
decltype(in_block_desc),
|
||||
decltype(in_block_desc.GetLengths())>{};
|
||||
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_global_desc),
|
||||
decltype(wei_block_desc),
|
||||
decltype(wei_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_global_desc),
|
||||
decltype(wei_block_desc),
|
||||
decltype(wei_block_desc.GetLengths())>{};
|
||||
|
||||
// set threadwise output tensor to 0
|
||||
threadwise_4d_tensor_set_zero(out_thread_desc, p_out_thread);
|
||||
@@ -165,14 +165,14 @@ __global__ void gridwise_direct_convolution_2(InGlobalDesc,
|
||||
c_block_data_begin += CPerBlock, __syncthreads())
|
||||
{
|
||||
// copy input tensor to LDS
|
||||
blockwise_in_copy.run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin,
|
||||
blockwise_in_copy.Run(p_in_global + in_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
p_in_block);
|
||||
|
||||
// copy weight tensor to LDS
|
||||
blockwise_wei_copy.run(
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global + wei_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
p_wei_block);
|
||||
|
||||
|
||||
@@ -106,19 +106,19 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
// blockwise copy
|
||||
// input: format is [C, Hi, Wi, N]
|
||||
constexpr auto blockwise_in_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths())>{};
|
||||
|
||||
// weight: format is [S,R,C,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_csrk_global_desc),
|
||||
decltype(wei_csrk_block_desc),
|
||||
decltype(wei_csrk_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_csrk_global_desc),
|
||||
decltype(wei_csrk_block_desc),
|
||||
decltype(wei_csrk_block_desc.GetLengths())>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
@@ -140,21 +140,20 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(
|
||||
I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
@@ -183,12 +182,12 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global_block_begin, p_in_block);
|
||||
blockwise_in_copy.Run(p_in_global_block_begin, p_in_block);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
@@ -200,7 +199,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_batch_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -136,39 +136,38 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
|
||||
#endif
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
blockwise_chwn_tensor_copy_with_padding<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths()),
|
||||
LowerPads>{};
|
||||
BlockwiseChwnTensorCopyPadded<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths()),
|
||||
LowerPads>{};
|
||||
|
||||
#if 1
|
||||
// weight: format is [C,S,R,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_csrk_global_desc),
|
||||
decltype(wei_csrk_block_desc),
|
||||
decltype(wei_csrk_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_csrk_global_desc),
|
||||
decltype(wei_csrk_block_desc),
|
||||
decltype(wei_csrk_block_desc.GetLengths())>{};
|
||||
#elif 1
|
||||
// weight: format is [C*S*R,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 1
|
||||
// weight: format is [C*S*R,K]
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
@@ -191,21 +190,20 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
|
||||
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(
|
||||
I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
@@ -229,7 +227,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global,
|
||||
blockwise_in_copy.Run(p_in_global,
|
||||
c_block_data_begin,
|
||||
ho_block_data_begin,
|
||||
wo_block_data_begin,
|
||||
@@ -243,7 +241,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
@@ -255,7 +253,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_batch_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -136,17 +136,17 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
#endif
|
||||
|
||||
constexpr auto blockwise_in_copy =
|
||||
blockwise_chwn_tensor_copy_with_padding<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths()),
|
||||
LowerPads>{};
|
||||
BlockwiseChwnTensorCopyPadded<BlockSize,
|
||||
Float,
|
||||
decltype(in_chwn_global_desc),
|
||||
decltype(in_chwn_block_desc),
|
||||
decltype(in_chwn_block_desc.GetLengths()),
|
||||
LowerPads>{};
|
||||
|
||||
#if 0
|
||||
// weight: format is [C,S,R,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_csrk_global_desc),
|
||||
decltype(wei_csrk_block_desc),
|
||||
@@ -154,21 +154,20 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
#elif 0
|
||||
// weight: format is [C*S*R,K]
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 1
|
||||
// weight: format is [C*S*R,K]
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
@@ -191,21 +190,20 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(
|
||||
I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
@@ -229,7 +227,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
|
||||
// prelog: load data
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global,
|
||||
blockwise_in_copy.Run(p_in_global,
|
||||
0,
|
||||
ho_block_data_begin,
|
||||
wo_block_data_begin,
|
||||
@@ -241,7 +239,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
w_block_pad_up);
|
||||
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block_0);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block_0);
|
||||
|
||||
p_wei_global_block_begin += CPerBlock * wei_ek_global_desc.GetStride(I0);
|
||||
|
||||
@@ -263,7 +261,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
// preload next data
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global,
|
||||
blockwise_in_copy.Run(p_in_global,
|
||||
c_block_data_begin,
|
||||
ho_block_data_begin,
|
||||
wo_block_data_begin,
|
||||
@@ -277,7 +275,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_begin, p_wei_block_next);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_begin, p_wei_block_next);
|
||||
#endif
|
||||
|
||||
// a series of batched GEMM
|
||||
@@ -287,7 +285,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_batch_gemm.run(p_wei_block_now +
|
||||
blockwise_batch_gemm.Run(p_wei_block_now +
|
||||
wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block_now + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_out_thread,
|
||||
@@ -310,7 +308,7 @@ __global__ void gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded_lds_p
|
||||
{
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_batch_gemm.run(p_wei_block_now +
|
||||
blockwise_batch_gemm.Run(p_wei_block_now +
|
||||
wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block_now + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_out_thread,
|
||||
|
||||
@@ -127,21 +127,20 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
|
||||
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(
|
||||
I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
@@ -175,15 +174,15 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
|
||||
#else
|
||||
// input: global mem to LDS,
|
||||
// no format conversion, this is wrong, for performance study only!
|
||||
blockwise_4d_tensor_copy<BlockSize>(in_nchw_global_desc,
|
||||
p_in_global +
|
||||
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
in_nchw_block_desc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths());
|
||||
Blockwise4dTensorCopy<BlockSize>(in_nchw_global_desc,
|
||||
p_in_global +
|
||||
in_nchw_global_desc.Get1dIndex(n_block_data_begin,
|
||||
c_block_data_begin,
|
||||
hi_block_data_begin,
|
||||
wi_block_data_begin),
|
||||
in_nchw_block_desc,
|
||||
p_in_block,
|
||||
in_nchw_block_desc.GetLengths());
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
@@ -200,7 +199,7 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
|
||||
#else
|
||||
// weight: global mem to LDS,
|
||||
// no format conversion, this is wrong, for performance study only!
|
||||
blockwise_4d_tensor_copy<BlockSize>(
|
||||
Blockwise4dTensorCopy<BlockSize>(
|
||||
wei_kcsr_global_desc,
|
||||
p_wei_global +
|
||||
wei_kcsr_global_desc.Get1dIndex(k_block_data_begin, c_block_data_begin, 0, 0),
|
||||
@@ -219,7 +218,7 @@ gridwise_implicit_gemm_convolution_1_nchw_kcsr(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -109,11 +109,11 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
// blockwise copy
|
||||
// wei: format is [S,R,C,K], no conversion needed
|
||||
constexpr auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
@@ -133,21 +133,20 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
Number<KPerThread>{}, Number<WoPerThread * NPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_batch_gemm =
|
||||
blockwise_1d_strided_batched_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(
|
||||
I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
Blockwise1dStridedBatchedGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxwn_block_mtx_desc),
|
||||
decltype(c_kxwn_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
0,
|
||||
in_chwn_block_desc.GetStride(I1),
|
||||
out_hkwn_thread_desc.GetStride(I0),
|
||||
HoPerBlock,
|
||||
HoPerThread,
|
||||
CPerThread,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_chwn_block_desc.GetElementSpace();
|
||||
@@ -183,7 +182,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
// format is [S,R,C,K], no conversion needed
|
||||
blockwise_wei_copy.run(p_wei_global + wei_srck_global_desc.Get1dIndex(
|
||||
blockwise_wei_copy.Run(p_wei_global + wei_srck_global_desc.Get1dIndex(
|
||||
0, 0, c_block_data_begin, k_block_data_begin),
|
||||
p_wei_block);
|
||||
#endif
|
||||
@@ -197,7 +196,7 @@ gridwise_implicit_gemm_convolution_1_nchw_srck_nkhw(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_batch_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_batch_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + in_chwn_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -25,7 +25,9 @@ template <unsigned GridSize,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
unsigned WeiBlockCopyThreadPerDim1>
|
||||
unsigned WeiBlockCopyThreadPerDim1,
|
||||
unsigned InBlockCopyDataPerRead,
|
||||
unsigned WeiBlockCopyDataPerRead>
|
||||
__global__ void
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
@@ -117,40 +119,52 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*S*R,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
@@ -170,18 +184,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_gemm =
|
||||
blockwise_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
true>{};
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
|
||||
@@ -208,10 +221,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
__syncthreads())
|
||||
{
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global_block_offset, p_in_block);
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_offset, p_wei_block);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
@@ -222,7 +235,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
@@ -283,10 +296,8 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
#endif
|
||||
if(n_data < N && h_data < Ho && w_data < Wo)
|
||||
{
|
||||
#if 1
|
||||
p_out_global[out_knhw_global_desc.Get1dIndex(k_data, n_data, h_data, w_data)] =
|
||||
p_out_thread[out_kb_thread_desc.Get1dIndex(k, b)];
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,7 +25,9 @@ template <unsigned GridSize,
|
||||
unsigned InBlockCopyThreadPerDim0,
|
||||
unsigned InBlockCopyThreadPerDim1,
|
||||
unsigned WeiBlockCopyThreadPerDim0,
|
||||
unsigned WeiBlockCopyThreadPerDim1>
|
||||
unsigned WeiBlockCopyThreadPerDim1,
|
||||
unsigned InBlockCopyDataPerRead,
|
||||
unsigned WeiBlockCopyDataPerRead>
|
||||
__global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline(
|
||||
InGlobalDesc,
|
||||
Float* const __restrict__ p_in_global,
|
||||
@@ -117,40 +119,52 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [CPerBlock*S*R,KPerBlock]
|
||||
#if 0
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
#elif 1
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyThreadPerDim0,
|
||||
WeiBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_wei_copy = Blockwise2dTensorCopy3<BlockSize,
|
||||
Float,
|
||||
decltype(wei_ek_global_desc),
|
||||
decltype(wei_ek_block_desc),
|
||||
decltype(wei_ek_block_desc.GetLengths()),
|
||||
WeiBlockCopyDataPerRead>{};
|
||||
#endif
|
||||
|
||||
// a series of blockwise GEMM
|
||||
@@ -170,18 +184,17 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_gemm =
|
||||
blockwise_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
true>{};
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
|
||||
@@ -205,10 +218,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
|
||||
// prelog : preload data
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global_block_offset, p_in_block_0);
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_0);
|
||||
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_offset, p_wei_block_0);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_0);
|
||||
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0);
|
||||
|
||||
@@ -234,10 +247,10 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
Float* p_wei_block_next = even_loop ? p_wei_block_1 : p_wei_block_0;
|
||||
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global_block_offset, p_in_block_next);
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block_next);
|
||||
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_offset, p_wei_block_next);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block_next);
|
||||
|
||||
// a series of GEMM
|
||||
for(unsigned s = 0; s < S; ++s)
|
||||
@@ -246,7 +259,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block_now + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
@@ -268,7 +281,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block_now + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -110,30 +110,29 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
#if 1
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#endif
|
||||
|
||||
// blockwise wei copy
|
||||
// format is [S,R,CPerBlock,KPerBlock]
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
@@ -152,18 +151,17 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_gemm =
|
||||
blockwise_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
true>{};
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmThreadPerClusterRow,
|
||||
GemmThreadPerClusterColumn,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
|
||||
@@ -191,12 +189,12 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
{
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global_block_offset, p_in_block);
|
||||
blockwise_in_copy.Run(p_in_global_block_offset, p_in_block);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global_block_offset, p_wei_block);
|
||||
blockwise_wei_copy.Run(p_wei_global_block_offset, p_wei_block);
|
||||
#endif
|
||||
|
||||
__syncthreads();
|
||||
@@ -209,7 +207,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -110,20 +110,19 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
#if 1
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 1
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths()),
|
||||
InBlockCopyThreadPerDim0,
|
||||
InBlockCopyThreadPerDim1>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy =
|
||||
blockwise_2d_tensor_copy_dummy_2<BlockSize,
|
||||
@@ -137,11 +136,11 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
// format is [S,R,CPerBlock,KPerBlock]
|
||||
#if 1
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
Blockwise4dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(wei_srck_global_desc),
|
||||
decltype(wei_srck_block_desc),
|
||||
decltype(wei_srck_block_desc.GetLengths())>{};
|
||||
#else
|
||||
const auto blockwise_wei_copy =
|
||||
blockwise_4d_tensor_copy_dummy<BlockSize,
|
||||
@@ -168,18 +167,17 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
const auto c_kxb_thread_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<KPerThread>{}, Number<BPerThread>{}); // constexpr doesn't compile
|
||||
|
||||
const auto blockwise_gemm =
|
||||
blockwise_gemm_block_a_block_b_thread_c<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmRowThreadPerCluster,
|
||||
GemmColumnThreadPerCluster,
|
||||
true>{};
|
||||
const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadC<BlockSize,
|
||||
decltype(a_cxk_block_mtx_desc),
|
||||
decltype(b_cxb_block_mtx_desc),
|
||||
decltype(c_kxb_thread_mtx_desc),
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
CPerThread,
|
||||
GemmRowThreadPerCluster,
|
||||
GemmColumnThreadPerCluster,
|
||||
true>{};
|
||||
|
||||
// LDS
|
||||
constexpr unsigned in_block_size = in_cb_block_desc.GetElementSpace();
|
||||
@@ -201,13 +199,13 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
// prelog: load data
|
||||
#if 1
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin),
|
||||
blockwise_in_copy.Run(p_in_global + in_cb_global_desc.Get1dIndex(0, b_block_data_begin),
|
||||
p_in_block_0);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(
|
||||
blockwise_wei_copy.Run(
|
||||
p_wei_global + wei_srck_global_desc.Get1dIndex(0, 0, 0, k_block_data_begin), p_wei_block_0);
|
||||
#endif
|
||||
|
||||
@@ -227,14 +225,14 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
#if 1
|
||||
// preload next data
|
||||
// input: global mem to LDS,
|
||||
blockwise_in_copy.run(p_in_global + in_cb_global_desc.Get1dIndex(
|
||||
blockwise_in_copy.Run(p_in_global + in_cb_global_desc.Get1dIndex(
|
||||
c_block_data_begin + CPerBlock, b_block_data_begin),
|
||||
p_in_block_next);
|
||||
#endif
|
||||
|
||||
#if 1
|
||||
// weight: global mem to LDS,
|
||||
blockwise_wei_copy.run(p_wei_global +
|
||||
blockwise_wei_copy.Run(p_wei_global +
|
||||
wei_srck_global_desc.Get1dIndex(
|
||||
0, 0, c_block_data_begin + CPerBlock, k_block_data_begin),
|
||||
p_wei_block_next);
|
||||
@@ -247,7 +245,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block_now + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
@@ -269,7 +267,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw_lds_pipeline
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block_now + wei_srck_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
p_in_block_now + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
Reference in New Issue
Block a user