diff --git a/driver/conv_bwd_data_driver_v2.cpp b/driver/conv_bwd_data_driver_v2.cpp index 3c271a37ad..cc3274c7eb 100644 --- a/driver/conv_bwd_data_driver_v2.cpp +++ b/driver/conv_bwd_data_driver_v2.cpp @@ -16,7 +16,7 @@ #include "device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 -#define USE_CONV_BWD_V4R1_XDL_NHWC 1 +#define USE_CONV_BWD_V4R1_XDL_NHWC 0 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 enum ConvBackwardDataAlgo @@ -114,7 +114,7 @@ int main(int argc, char* argv[]) const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; #endif -#if 1 +#if 0 constexpr index_t in_vector_size = 1; using in_data_t = float; using acc_data_t = float; diff --git a/driver/conv_driver_v2.cpp b/driver/conv_driver_v2.cpp index 93b13caaa4..8f972ad37e 100644 --- a/driver/conv_driver_v2.cpp +++ b/driver/conv_driver_v2.cpp @@ -24,15 +24,15 @@ #include "device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" #define USE_DYNAMIC_MODE 1 -#define USE_CONV_FWD_V4R4_NCHW 1 -#define USE_CONV_FWD_V4R4_NHWC 1 -#define USE_CONV_FWD_V4R4R2_NHWC 1 -#define USE_CONV_FWD_V4R5_NCHW 1 -#define USE_CONV_FWD_V4R5R2_NCHW 1 +#define USE_CONV_FWD_V4R4_NCHW 0 +#define USE_CONV_FWD_V4R4_NHWC 0 +#define USE_CONV_FWD_V4R4R2_NHWC 0 +#define USE_CONV_FWD_V4R5_NCHW 0 +#define USE_CONV_FWD_V4R5R2_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 #define USE_CONV_FWD_V4R4_XDL_NCHW 1 -#define USE_CONV_FWD_V4R4R2_XDL_NHWC 1 -#define USE_CONV_FWD_V4R4R3_XDL_NHWC 1 +#define USE_CONV_FWD_V4R4R2_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R3_XDL_NHWC 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo @@ -138,7 +138,7 @@ int main(int argc, char* argv[]) const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; #endif -#if 1 +#if 0 using in_data_t = float; using acc_data_t = float; using out_data_t = float; diff --git a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 8332798690..0711a5f262 100644 --- a/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,7 +56,35 @@ void device_dynamic_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_k const auto out_n_ho_wo_k_desc = make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); -#if 1 +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 constexpr index_t BlockSize = 256; diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp index f305cb9ae2..845095b947 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp @@ -58,7 +58,6 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw( #if 1 // cdata = 64, BlockSize = 256, 128x128x8 - // b thread copy 4x1 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlockM1 = 128; diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index e0a89d2af3..5b6a94a0b7 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -57,90 +57,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk make_dynamic_naive_tensor_descriptor_packed_v2(out_n_k_ho_wo_lengths); #if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 4] + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -168,7 +85,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -194,6 +111,62 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 64; + constexpr index_t GemmNPerWave = 64; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif diff --git a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index f423f6228e..f1a0bed7c0 100644 --- a/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,7 +56,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh const auto out_n_ho_wo_k_desc = make_dynamic_naive_tensor_descriptor_packed_v2(out_n_ho_wo_k_lengths); -#if 1 +#if 0 // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 constexpr index_t BlockSize = 256; @@ -84,7 +84,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 +#elif 0 // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 constexpr index_t BlockSize = 256; @@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nh constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 +#elif 0 // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 constexpr index_t BlockSize = 256;