|
|
|
|
@@ -57,90 +57,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|
|
|
|
make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths);
|
|
|
|
|
|
|
|
|
|
#if 0
|
|
|
|
|
constexpr index_t BlockSize = 256;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerBlock = 128;
|
|
|
|
|
constexpr index_t GemmNPerBlock = 128;
|
|
|
|
|
constexpr index_t GemmKPerBlock = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerWave = 64;
|
|
|
|
|
constexpr index_t GemmNPerWave = 64;
|
|
|
|
|
constexpr index_t GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
constexpr index_t MRepeat = 1;
|
|
|
|
|
constexpr index_t NRepeat = 1;
|
|
|
|
|
|
|
|
|
|
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
|
|
|
|
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
|
|
|
|
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
|
|
|
|
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
|
|
|
|
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
|
|
|
|
#elif 0
|
|
|
|
|
// [M, N, K0, K1] = [256, 128, 4, 8]
|
|
|
|
|
constexpr index_t BlockSize = 256;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerBlock = 256;
|
|
|
|
|
constexpr index_t GemmNPerBlock = 128;
|
|
|
|
|
constexpr index_t GemmKPerBlock = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerWave = 64;
|
|
|
|
|
constexpr index_t GemmNPerWave = 64;
|
|
|
|
|
constexpr index_t GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
constexpr index_t MRepeat = 2;
|
|
|
|
|
constexpr index_t NRepeat = 1;
|
|
|
|
|
|
|
|
|
|
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
|
|
|
|
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
|
|
|
|
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
|
|
|
|
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4;
|
|
|
|
|
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
|
|
|
|
#elif 0
|
|
|
|
|
// [M, N, K0, K1] = [256, 128, 4, 8]
|
|
|
|
|
constexpr index_t BlockSize = 256;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerBlock = 256;
|
|
|
|
|
constexpr index_t GemmNPerBlock = 128;
|
|
|
|
|
constexpr index_t GemmKPerBlock = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerWave = 64;
|
|
|
|
|
constexpr index_t GemmNPerWave = 64;
|
|
|
|
|
constexpr index_t GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
constexpr index_t MRepeat = 2;
|
|
|
|
|
constexpr index_t NRepeat = 1;
|
|
|
|
|
|
|
|
|
|
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
|
|
|
|
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
|
|
|
|
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
|
|
|
|
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
|
|
|
|
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
|
|
|
|
#elif 1
|
|
|
|
|
// [M, N, K0, K1] = [256, 128, 4, 4]
|
|
|
|
|
// [M, N, K0, K1] = [256, 128, 4, 4] for fp32
|
|
|
|
|
constexpr index_t BlockSize = 256;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerBlock = 256;
|
|
|
|
|
@@ -168,7 +85,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
|
|
|
|
#elif 0
|
|
|
|
|
// [M, N, K0, K1] = [128, 128, 4, 4]
|
|
|
|
|
// [M, N, K0, K1] = [128, 128, 4, 4] for fp32
|
|
|
|
|
constexpr index_t BlockSize = 256;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerBlock = 128;
|
|
|
|
|
@@ -194,6 +111,62 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk
|
|
|
|
|
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
|
|
|
|
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
|
|
|
|
#elif 1
|
|
|
|
|
// [M, N, K0, K1] = [256, 128, 4, 8] for fp16
|
|
|
|
|
constexpr index_t BlockSize = 256;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerBlock = 256;
|
|
|
|
|
constexpr index_t GemmNPerBlock = 128;
|
|
|
|
|
constexpr index_t GemmKPerBlock = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerWave = 64;
|
|
|
|
|
constexpr index_t GemmNPerWave = 64;
|
|
|
|
|
constexpr index_t GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
constexpr index_t MRepeat = 2;
|
|
|
|
|
constexpr index_t NRepeat = 1;
|
|
|
|
|
|
|
|
|
|
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>;
|
|
|
|
|
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
|
|
|
|
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
|
|
|
|
|
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
|
|
|
|
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
|
|
|
|
#elif 1
|
|
|
|
|
// [M, N, K0, K1] = [128, 128, 4, 8] for fp16
|
|
|
|
|
constexpr index_t BlockSize = 256;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerBlock = 128;
|
|
|
|
|
constexpr index_t GemmNPerBlock = 128;
|
|
|
|
|
constexpr index_t GemmKPerBlock = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmMPerWave = 64;
|
|
|
|
|
constexpr index_t GemmNPerWave = 64;
|
|
|
|
|
constexpr index_t GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
constexpr index_t MRepeat = 1;
|
|
|
|
|
constexpr index_t NRepeat = 1;
|
|
|
|
|
|
|
|
|
|
using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
|
|
|
|
|
using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8;
|
|
|
|
|
constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8;
|
|
|
|
|
|
|
|
|
|
using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>;
|
|
|
|
|
using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
|
|
|
|
|
constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4;
|
|
|
|
|
|
|
|
|
|
constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
|