mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
bug fixes
This commit is contained in:
@@ -389,8 +389,8 @@ int main()
|
||||
constexpr unsigned S = 3;
|
||||
constexpr unsigned R = 3;
|
||||
|
||||
constexpr unsigned HPad = 1;
|
||||
constexpr unsigned WPad = 1;
|
||||
constexpr unsigned HPad = 0;
|
||||
constexpr unsigned WPad = 0;
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned N = 64;
|
||||
@@ -593,8 +593,6 @@ int main()
|
||||
device_implicit_gemm_convolution_2_cnhw_srck_knhw
|
||||
#elif 1
|
||||
device_implicit_gemm_convolution_2_cnhw_csrk_knhw
|
||||
#elif 0
|
||||
device_winograd_convolution
|
||||
#endif
|
||||
(in_nchw_desc, in_nchw, wei_kcsr_desc, wei_kcsr, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
|
||||
|
||||
@@ -67,7 +67,30 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
|
||||
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
|
||||
|
||||
#if 1
|
||||
#if 0
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned WeiBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned InBlockCopyDataPerRead = 2;
|
||||
constexpr unsigned WeiBlockCopyDataPerRead = 4;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// 1x1, 28x28
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
@@ -120,7 +143,7 @@ void device_implicit_gemm_convolution_2_cnhw_csrk_knhw(InDesc,
|
||||
|
||||
#if 1
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw
|
||||
#elif 0
|
||||
#elif 1
|
||||
gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
#endif
|
||||
<GridSize,
|
||||
|
||||
@@ -68,64 +68,55 @@ void device_implicit_gemm_convolution_2_cnhw_srck_knhw(InDesc,
|
||||
Tensor<T> out_knhw(make_TensorDescriptor(out_knhw_desc));
|
||||
|
||||
#if 0
|
||||
constexpr unsigned BPerBlock = 256;
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 1;
|
||||
constexpr unsigned CPerBlock = 1;
|
||||
|
||||
constexpr unsigned BPerThread = 8;
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 1;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 1;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 1;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 1;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned BlockSize = 32;
|
||||
#elif 0
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
constexpr unsigned CPerBlock = 4;
|
||||
|
||||
constexpr unsigned BPerThread = 8;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 0
|
||||
constexpr unsigned BPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 2;
|
||||
|
||||
constexpr unsigned BPerThread = 8;
|
||||
constexpr unsigned KPerThread = 8;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 4;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 2;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 64;
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
#elif 1
|
||||
// 1x1, 28x28
|
||||
constexpr unsigned BPerBlock = 64;
|
||||
constexpr unsigned KPerBlock = 128;
|
||||
constexpr unsigned KPerBlock = 64;
|
||||
constexpr unsigned CPerBlock = 8;
|
||||
|
||||
constexpr unsigned BPerThread = 4;
|
||||
constexpr unsigned KPerThread = 16;
|
||||
constexpr unsigned CPerThread = 2;
|
||||
constexpr unsigned CPerThread = 1;
|
||||
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 8;
|
||||
constexpr unsigned GemmThreadPerColumnPerCluster = 4;
|
||||
constexpr unsigned GemmThreadPerRowPerCluster = 8;
|
||||
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 8;
|
||||
constexpr unsigned InBlockCopyThreadPerDim0 = 4;
|
||||
constexpr unsigned InBlockCopyThreadPerDim1 = 16;
|
||||
|
||||
constexpr unsigned BlockSize = 128;
|
||||
constexpr unsigned BlockSize = 64;
|
||||
#endif
|
||||
|
||||
constexpr unsigned GridSize =
|
||||
|
||||
@@ -15,6 +15,24 @@ __host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
template <unsigned L0, unsigned L1, unsigned Align>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1>,
|
||||
Number<Align>)
|
||||
{
|
||||
constexpr unsigned L1_align = Align * ((L1 + Align - 1) / Align);
|
||||
return Sequence<L1_align, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3, unsigned Align>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_aligned(Sequence<L0, L1, L2, L3>,
|
||||
Number<Align>)
|
||||
{
|
||||
constexpr unsigned L3_align = Align * ((L3 + Align - 1) / Align);
|
||||
return Sequence<L1 * L2 * L3_align, L2 * L3_align, L3_align, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
|
||||
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
|
||||
@@ -170,6 +188,13 @@ __host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths, Stride
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class Lengths, unsigned Align>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor_aligned(Lengths, Number<Align>)
|
||||
{
|
||||
using Strides = decltype(calculate_default_strides_aligned(Lengths{}, Number<Align>{}));
|
||||
return ConstantTensorDescriptor<Lengths, Strides>{};
|
||||
}
|
||||
|
||||
template <class TDesc>
|
||||
__host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
{
|
||||
|
||||
@@ -203,6 +203,9 @@ struct Blockwise2dTensorCopy2
|
||||
{
|
||||
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
|
||||
|
||||
using Float4 = float4;
|
||||
using Float2 = float2;
|
||||
|
||||
if(get_thread_local_1d_id() >= ThreadPerDim0 * ThreadPerDim1)
|
||||
return;
|
||||
|
||||
@@ -212,18 +215,28 @@ struct Blockwise2dTensorCopy2
|
||||
constexpr auto src_desc = SrcDesc{};
|
||||
constexpr auto dst_desc = DstDesc{};
|
||||
|
||||
// check alignment
|
||||
constexpr bool align_v4 =
|
||||
src_desc.GetStride(I0) % 4 == 0 && dst_desc.GetStride(I0) % 4 == 0;
|
||||
|
||||
constexpr bool align_v2 =
|
||||
src_desc.GetStride(I0) % 2 == 0 && dst_desc.GetStride(I0) % 2 == 0;
|
||||
|
||||
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
|
||||
|
||||
constexpr unsigned Dim0Loop = L0 / ThreadPerDim0;
|
||||
constexpr bool d0_has_tail = (L0 > ThreadPerDim0 * Dim0Loop);
|
||||
|
||||
constexpr unsigned Dim1V4Loop = L1 / (ThreadPerDim1 * 4);
|
||||
constexpr unsigned Dim1V4Loop = align_v4 ? L1 / (ThreadPerDim1 * 4) : 0;
|
||||
|
||||
constexpr unsigned Dim1V2Loop =
|
||||
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2);
|
||||
align_v2 ? (L1 - Dim1V4Loop * (ThreadPerDim1 * 4)) / (ThreadPerDim1 * 2) : 0;
|
||||
|
||||
constexpr unsigned Dim1V1Loop =
|
||||
(L1 - Dim1V4Loop * (ThreadPerDim1 * 4) - Dim1V2Loop * (ThreadPerDim1 * 2)) /
|
||||
ThreadPerDim1;
|
||||
|
||||
constexpr bool d1_has_tail =
|
||||
(L1 > ThreadPerDim1 * (4 * Dim1V4Loop + 2 * Dim1V2Loop + Dim1V1Loop));
|
||||
|
||||
@@ -239,8 +252,8 @@ struct Blockwise2dTensorCopy2
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float4*>(p_src + sindex));
|
||||
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<Float4*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v2
|
||||
@@ -252,8 +265,8 @@ struct Blockwise2dTensorCopy2
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float2*>(p_src + sindex));
|
||||
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<Float2*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v1
|
||||
@@ -300,8 +313,8 @@ struct Blockwise2dTensorCopy2
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float4*>(p_src + sindex));
|
||||
*(reinterpret_cast<Float4*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<Float4*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v2
|
||||
@@ -313,8 +326,8 @@ struct Blockwise2dTensorCopy2
|
||||
const unsigned sindex = src_desc.Get1dIndex(did0, did1);
|
||||
const unsigned dindex = dst_desc.Get1dIndex(did0, did1);
|
||||
|
||||
*(reinterpret_cast<float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<float2*>(p_src + sindex));
|
||||
*(reinterpret_cast<Float2*>(p_dst + dindex)) =
|
||||
*(reinterpret_cast<Float2*>(p_src + sindex));
|
||||
}
|
||||
|
||||
// v1
|
||||
@@ -356,7 +369,7 @@ template <unsigned BlockSize,
|
||||
class Float,
|
||||
class SrcDesc,
|
||||
class DstDesc,
|
||||
class SrcOpLengths,
|
||||
class CopyLengths,
|
||||
unsigned DataPerRead>
|
||||
struct Blockwise2dTensorCopy3
|
||||
{
|
||||
@@ -374,16 +387,23 @@ struct Blockwise2dTensorCopy3
|
||||
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
|
||||
"wrong! only support DataPerRead == 1, 2 or 4!\n");
|
||||
|
||||
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
|
||||
static_assert(SrcDesc{}.GetStride(I0) % DataPerRead == 0 &&
|
||||
DstDesc{}.GetStride(I0) % DataPerRead == 0,
|
||||
"src and dst stride should be multiple of DataPerRead to keep alignment");
|
||||
|
||||
static_assert(L1 % DataPerRead == 0, "wrong! only support mod(L1, DataPerRead) == 0\n");
|
||||
constexpr unsigned L0 = CopyLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr unsigned thread_per_d1 = L1 / DataPerRead;
|
||||
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
static_assert(thread_per_d1 <= BlockSize,
|
||||
"wrong! not enough threads to cover L1 dimension\n");
|
||||
// we allow out-of-bound read from src in D1 dimension,
|
||||
// but we need to make sure dst stride is big enough,
|
||||
// so that the out-of-bound write won't overwrite next line
|
||||
static_assert(thread_per_d1 * DataPerRead <= DstDesc{}.GetStride(I0),
|
||||
"wrong! out-of-bound write will overwrite next line!\n");
|
||||
|
||||
static_assert(thread_per_d0 >= 1, "wrong! not enough threads to cover L1 dimension\n");
|
||||
|
||||
const unsigned thread_id_d0 = get_thread_local_1d_id() / thread_per_d1;
|
||||
const unsigned thread_id_d1 = get_thread_local_1d_id() - thread_id_d0 * thread_per_d1;
|
||||
@@ -402,17 +422,17 @@ struct Blockwise2dTensorCopy3
|
||||
constexpr auto I0 = Number<0>{};
|
||||
constexpr auto I1 = Number<1>{};
|
||||
|
||||
constexpr unsigned L0 = SrcOpLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = SrcOpLengths{}.Get(I1);
|
||||
constexpr unsigned L0 = CopyLengths{}.Get(I0);
|
||||
constexpr unsigned L1 = CopyLengths{}.Get(I1);
|
||||
|
||||
constexpr unsigned thread_per_d1 = L1 / DataPerRead;
|
||||
constexpr unsigned thread_per_d1 = (L1 + DataPerRead - 1) / DataPerRead;
|
||||
constexpr unsigned thread_per_d0 = BlockSize / thread_per_d1;
|
||||
|
||||
constexpr unsigned num_active_thread = thread_per_d0 * thread_per_d1;
|
||||
|
||||
if(BlockSize > num_active_thread)
|
||||
{
|
||||
if(get_thread_local_1d_id() > num_active_thread)
|
||||
if(get_thread_local_1d_id() >= num_active_thread)
|
||||
{
|
||||
return;
|
||||
}
|
||||
@@ -420,8 +440,6 @@ struct Blockwise2dTensorCopy3
|
||||
|
||||
constexpr unsigned nloop_d0 = L0 / thread_per_d0;
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
constexpr unsigned src_loop_stride = SrcDesc{}.GetStride(I0) * thread_per_d0;
|
||||
constexpr unsigned dst_loop_stride = DstDesc{}.GetStride(I0) * thread_per_d0;
|
||||
|
||||
@@ -450,6 +468,8 @@ struct Blockwise2dTensorCopy3
|
||||
}
|
||||
}
|
||||
|
||||
constexpr bool has_tail_d0 = (L0 > nloop_d0 * thread_per_d0);
|
||||
|
||||
if(has_tail_d0)
|
||||
{
|
||||
constexpr unsigned tail_d0 = L0 - nloop_d0 * thread_per_d0;
|
||||
|
||||
@@ -173,7 +173,7 @@ gridwise_implicit_gemm_convolution_1_chwn_csrk_khwn_padded(Float* const __restri
|
||||
// a series of blockwise batched GEMM
|
||||
// C_matrix += transpose(A_matrix) * B_matrix
|
||||
// A_matrix and B_matrix saved in LDS, C_matrix saved in register
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[S,R,C,K]
|
||||
// A_matrix[C,K] is a sub-matrix of wei_block[C,S,R,K]
|
||||
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
|
||||
// C_matrix[K,Wo*N] is a sub-matrix of out_block[Ho,K,Wo,N]
|
||||
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
|
||||
@@ -70,35 +70,32 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
const unsigned k_block_data_begin = k_block_work_id * KPerBlock;
|
||||
const unsigned b_block_data_begin = b_block_work_id * BPerBlock;
|
||||
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0)
|
||||
{
|
||||
printf("K %u B %u, BGhostRead %u\n", K, B, BGhostRead);
|
||||
|
||||
printf("%u %u, KBlockWork %u BBlockWork %u, k_block_data_begin %u b_block_data_begin %u\n",
|
||||
get_block_1d_id(),
|
||||
get_thread_local_1d_id(),
|
||||
KBlockWork,
|
||||
BBlockWork,
|
||||
k_block_data_begin,
|
||||
b_block_data_begin);
|
||||
}
|
||||
#endif
|
||||
|
||||
// flattend (2d) tensor view of gridwise input
|
||||
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
|
||||
constexpr auto in_cb_global_desc = make_ConstantTensorDescriptor(Sequence<C, B>{});
|
||||
constexpr auto wei_ek_global_desc = make_ConstantTensorDescriptor(Sequence<C * S * R, K>{});
|
||||
|
||||
// tensor view of blockwise input and weight
|
||||
#if 0
|
||||
constexpr auto in_cb_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, BPerBlock + BGhostRead>{});
|
||||
#else
|
||||
constexpr auto in_cb_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, BPerBlock + BGhostRead>{}, Number<InBlockCopyDataPerRead>{});
|
||||
#endif
|
||||
|
||||
#if 0
|
||||
constexpr auto wei_ek_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock * S * R, KPerBlock>{});
|
||||
|
||||
constexpr auto wei_csrk_block_desc =
|
||||
make_ConstantTensorDescriptor(Sequence<CPerBlock, S, R, KPerBlock>{});
|
||||
#else
|
||||
constexpr auto wei_ek_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock * S * R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
|
||||
constexpr auto wei_csrk_block_desc = make_ConstantTensorDescriptor_aligned(
|
||||
Sequence<CPerBlock, S, R, KPerBlock>{}, Number<WeiBlockCopyDataPerRead>{});
|
||||
#endif
|
||||
|
||||
// tensor view of threadwise output in register
|
||||
constexpr auto out_kb_thread_desc =
|
||||
@@ -107,8 +104,16 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
#if 0
|
||||
if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
|
||||
{
|
||||
print_ConstantTensorDescriptor(in_cnhw_global_desc, "in_cnhw_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_csrk_global_desc, "wei_csrk_global_desc");
|
||||
print_ConstantTensorDescriptor(out_knhw_global_desc, "out_knhw_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_cb_global_desc, "in_cb_global_desc");
|
||||
print_ConstantTensorDescriptor(wei_ek_global_desc, "wei_ek_global_desc");
|
||||
|
||||
print_ConstantTensorDescriptor(in_cb_block_desc, "in_cb_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_csrk_block_desc, "wei_csrk_block_desc");
|
||||
print_ConstantTensorDescriptor(wei_ek_block_desc, "wei_ek_block_desc");
|
||||
print_ConstantTensorDescriptor(out_kb_thread_desc, "out_kb_thread_desc");
|
||||
|
||||
printf("KPerBlock %u\n", KPerBlock);
|
||||
@@ -120,10 +125,10 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
Float,
|
||||
decltype(in_cb_global_desc),
|
||||
decltype(in_cb_block_desc),
|
||||
decltype(in_cb_block_desc.GetLengths())>{};
|
||||
#elif 0
|
||||
const auto blockwise_in_copy = Blockwise2dTensorCopy2<BlockSize,
|
||||
Float,
|
||||
@@ -170,11 +175,13 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
// a series of blockwise GEMM
|
||||
// c_mtx += transpose(a_mtx) * b_mtx
|
||||
// a_mtx and b_mtx saved in LDS, c_mtx saved in register
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[S,R,C,K]
|
||||
// a_mtx[C,K] is a sub-matrix of wei_block[C,S,R,K]
|
||||
// b_mtx[C,B] is a subset of in_block[C,B + BGhostRead]
|
||||
// c_mtx[K,B] is out_block[K,B]
|
||||
const auto a_cxk_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{}, Number<KPerBlock>{}); // constexpr doesn't compile
|
||||
Number<CPerBlock>{},
|
||||
Number<KPerBlock>{},
|
||||
Number<wei_csrk_block_desc.GetStride(I0)>{}); // constexpr doesn't compile
|
||||
|
||||
const auto b_cxb_block_mtx_desc = make_ConstantMatrixDescriptor(
|
||||
Number<CPerBlock>{},
|
||||
@@ -217,7 +224,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
|
||||
for(unsigned c_block_data_begin = 0; c_block_data_begin < C; c_block_data_begin += CPerBlock,
|
||||
p_in_global_block_offset += CPerBlock * in_cb_global_desc.GetStride(I0),
|
||||
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I2),
|
||||
p_wei_global_block_offset += CPerBlock * wei_csrk_global_desc.GetStride(I0),
|
||||
__syncthreads())
|
||||
{
|
||||
// input: global mem to LDS,
|
||||
@@ -233,9 +240,9 @@ gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw(InGlobalDesc,
|
||||
{
|
||||
for(unsigned r = 0; r < R; ++r)
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
auto f_accum = [](auto& acc, const auto&& v) { acc += v; };
|
||||
|
||||
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -259,7 +259,7 @@ __global__ void gridwise_implicit_gemm_convolution_2_cnhw_csrk_knhw_lds_pipeline
|
||||
{
|
||||
auto f_accum = [](auto& c, const auto&& ab) { c += ab; };
|
||||
|
||||
blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(s, r, 0, 0),
|
||||
blockwise_gemm.Run(p_wei_block_now + wei_csrk_block_desc.Get1dIndex(0, s, r, 0),
|
||||
p_in_block_now + s * Wi + r,
|
||||
p_out_thread,
|
||||
f_accum);
|
||||
|
||||
@@ -108,7 +108,7 @@ gridwise_implicit_gemm_convolution_2_cnhw_srck_knhw(InGlobalDesc,
|
||||
|
||||
// blockwise in copy
|
||||
// formmat is [CPerBlock,BPerBlock + BGhostRead]
|
||||
#if 1
|
||||
#if 0
|
||||
const auto blockwise_in_copy =
|
||||
Blockwise2dTensorCopy1<BlockSize,
|
||||
Float,
|
||||
|
||||
Reference in New Issue
Block a user