From 720cf3d6b2a2bb32f4cad277bf7a1c745508e37a Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 6 Oct 2021 11:12:36 -0500 Subject: [PATCH] Tweak GEMM kernel (#38) * add parameters * tweak gemm * tweak * update conv * update script * adding bwd 1x1 * update script * adding 1x1 bwd * debugging bwd 1x1 failure * update script * update script * test * test v100 * clean up [ROCm/composable_kernel commit: b3e8d57d51300b88b591900621f71b6a1b3a7acc] --- ...lution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp | 103 ++++- .../multi_index_transform_helper.hpp | 2 +- .../gridwise_gemm_xdlops_v2r3.hpp | 173 ++++++-- composable_kernel/include/utility/config.hpp | 4 +- host/driver_offline/include/debug.hpp | 13 + ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 44 +- ...icit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp | 271 ++++++++---- ..._gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp | 389 ++++++++++++++++++ ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 29 +- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 76 +++- .../include/device_gemm_xdlops_km_kn_mn.hpp | 332 +++++++++++++-- .../include/device_gemm_xdlops_km_kn_nm.hpp | 263 ++++++++++++ .../include/device_gemm_xdlops_km_nk_mn.hpp | 332 +++++++++++++-- .../include/device_gemm_xdlops_km_nk_nm.hpp | 263 ++++++++++++ .../include/device_gemm_xdlops_mk_kn_mn.hpp | 334 +++++++++++++-- .../include/device_gemm_xdlops_mk_kn_nm.hpp | 291 +++++++++++++ .../include/device_gemm_xdlops_mk_nk_mn.hpp | 383 ++++++++++++++--- .../include/device_gemm_xdlops_mk_nk_nm.hpp | 347 ++++++++++++++++ .../include/driver_gemm_xdlops_v2r3.hpp | 20 +- .../src/conv_bwd_driver_offline.cpp | 59 ++- .../src/conv_fwd_driver_offline.cpp | 3 +- .../src/conv_wrw_driver_offline.cpp | 3 +- .../src/gemm_driver_offline.cpp | 192 +++++---- host/host_tensor/include/device.hpp | 4 + host/host_tensor/include/gemm_common.hpp | 4 + host/host_tensor/include/host_gemm.hpp | 72 ++++ script/docker-rocm4.3.1.sh | 14 + script/run.sh | 151 +++++-- 28 files changed, 3642 insertions(+), 529 deletions(-) create mode 100644 host/driver_offline/include/debug.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp create mode 100755 script/docker-rocm4.3.1.sh diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp index 9c60e8c3ac..fa78d76965 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -21,8 +21,8 @@ template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( @@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, const InRightPads& in_right_pads, - Number, - Number, + IYTilda i_ytilda, + IXTilda i_xtilda, Number) { constexpr auto I0 = Number<0>{}; @@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto GemmK1 = Number{}; - constexpr auto IYTilda = Number{}; - constexpr auto IXTilda = Number{}; + constexpr auto GemmK1 = Number{}; const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); @@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); - const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda); const auto K1 = GemmK1; const auto K0 = K / K1; @@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilda), - make_freeze_transform(IXTilda), + make_freeze_transform(i_ytilda), + make_freeze_transform(i_xtilda), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_freeze_transform(IYTilda), + make_freeze_transform(i_ytilda), make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), - make_freeze_transform(IXTilda), + make_freeze_transform(i_xtilda), make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, @@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( in_gemmm_gemmn_grid_desc); } +// A: out +// B: wei +// C: in +// Number of GEMMs = 1 +// GemmM = N * Ho * Wo +// GemmN = C +// GemmK = K +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& /* wei_k_y_x_c_grid_desc */, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp index 32acceb608..9a73799173 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp @@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform( return LeftPad{low_length, left_pad}; } -template +template __host__ __device__ constexpr auto make_right_pad_transform( const LowLength& low_length, const RightPadLength& right_pad, diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index c6f491dc47..e3b0054bec 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -29,7 +29,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const AK0MK1GridDesc a_k0_m_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -132,7 +132,9 @@ template + bool CAccessOrderMRepeatNRepeat, + bool ABlockLdsExtraM, + bool BBlockLdsExtraN> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; @@ -152,14 +154,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -171,29 +193,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ __device__ static constexpr bool CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) + const CMNGridDesc& c_m_n_grid_desc, + index_t M01, + index_t N01) { - // TODO: turn on this static_assert(is_known_at_compile_time>::value, "wrong! K1 need to be known at compile-time"); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && (NPerBlock % (NRepeat * NPerXDL)) == 0, "Invalid tuning param!"); + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0); + return true; } __host__ __device__ static constexpr index_t @@ -212,11 +250,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { constexpr auto max_lds_align = K1; - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); -#elif 1 - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))), - make_tuple(Sequence<1, 0>{}), - make_tuple(Sequence<0>{})); -#endif + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); return c_blockid_to_m0_n0_block_cluster_adaptor; } using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1)); __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -296,14 +367,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // A matrix blockwise copy auto a_blockwise_copy = diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index c229162d9b..5ee4bb9c64 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -90,8 +90,8 @@ #endif // pass tensor descriptor by value or void* -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0 // merge transformation use magic number division #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 diff --git a/host/driver_offline/include/debug.hpp b/host/driver_offline/include/debug.hpp new file mode 100644 index 0000000000..72fd0763ba --- /dev/null +++ b/host/driver_offline/include/debug.hpp @@ -0,0 +1,13 @@ +#ifndef DEBUG_HPP +#define DEBUG_HPP + +namespace debug { +namespace debug_driver_gemm_xdlops_v2r3 { + +// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping +static ck::index_t M01 = 1; +static ck::index_t N01 = 1; + +} // namespace debug_driver_gemm_xdlops_v2r3 +} // namespace debug +#endif diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 8f49473563..b5ff1db296 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -48,8 +48,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); -#if 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 +#if 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -76,7 +76,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 +#elif 0 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -105,7 +105,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; #elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -159,34 +159,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; #endif @@ -294,13 +266,17 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), wei_gemmk0_gemmm_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc, + debug_driver_gemm_xdlops_v2r3::M01, + debug_driver_gemm_xdlops_v2r3::N01, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks, in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 2cbae2daf3..28d6226f1b 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -159,25 +159,93 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif - const auto descs = - transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc, - wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - I0, - I0, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 @@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: + // gemmk1 constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 @@ -215,7 +284,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - //clang-format on + // clang-format on constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; @@ -225,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + float ave_time = 0; + + for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda) + { + for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) + { + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( + out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + i_ytilda, + i_xtilda, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0); + + if(GemmK0 != 0) + { + ave_time += driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy #if 0 - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, #else - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, #endif - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - true // CAccessOrderMRepeatNRepeat - >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + true, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + } + } + } { const auto N = out_n_ho_wo_k_lengths[I0]; diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp new file mode 100644 index 0000000000..d6955ec000 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp @@ -0,0 +1,389 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations&, + const InLeftPads&, + const InRightPads&, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0>{}, // 1+: gemmm + Sequence<0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0>{}, // 1-: gemmm + Sequence<0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0>{}, // 1+: gemmn + Sequence<0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0>{}, // 1-: Gemmn + Sequence<0, 0, 0>{})); // 2-: Gemmk1 + + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + // clang-format on + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( + out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy +#if 0 + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, +#else + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, +#endif + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + true, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // copy result back to host + in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index e97bc9c1c7..b8ecfb4be9 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -203,18 +203,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false>(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN + >(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); float perf = static_cast(calculate_convolution_flops( in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 52432664de..01e5c57ab4 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -49,7 +49,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -77,7 +77,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -133,7 +133,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -160,8 +160,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -189,7 +189,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif @@ -316,13 +372,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, in_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp index d9169649e6..c44aa7d9a2 100644 --- a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template -void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, - const BDesc& b_k_n_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor& a_k_m, +template +void device_gemm_xdlops_km_kn_mn(const Tensor& a_k_m, const Tensor& b_k_n, Tensor& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -60,9 +49,121 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 256; @@ -88,46 +189,185 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_k_m_grid_desc.GetLength(I0); - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_k_n.mDesc.GetLengths()[1]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_k_m_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_k_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -147,9 +387,9 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; for(index_t i = 0; i < 5; ++i) { @@ -194,13 +434,17 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp new file mode 100644 index 0000000000..abaaf32113 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp @@ -0,0 +1,263 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_kn_nm(const Tensor& a_k_m, + const Tensor& b_k_n, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp index 90e258d581..0a97d361d4 100644 --- a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template -void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, - const BDesc& b_n_k_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor& a_k_m, +template +void device_gemm_xdlops_km_nk_mn(const Tensor& a_k_m, const Tensor& b_n_k, Tensor& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -60,9 +49,121 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 256; @@ -88,46 +189,185 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_k_m_grid_desc.GetLength(I0); - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_n_k_grid_desc.GetLength(I0); + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_n_k.mDesc.GetLengths()[0]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_k_m_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_n_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -147,9 +387,9 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; for(index_t i = 0; i < 5; ++i) { @@ -194,13 +434,17 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), static_cast(b_n_k_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp new file mode 100644 index 0000000000..d51caa3847 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp @@ -0,0 +1,263 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_nk_nm(const Tensor& a_k_m, + const Tensor& b_n_k, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp index ab235d97e7..30ede2517b 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template -void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, - const BDesc& b_k_n_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor& a_m_k, +template +void device_gemm_xdlops_mk_kn_mn(const Tensor& a_m_k, const Tensor& b_k_n, Tensor& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -33,8 +22,148 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, b_k_n_device_buf.ToDevice(b_k_n.mData.data()); c_m_n_device_buf.ToDevice(c_m_n.mData.data()); -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 256; @@ -88,46 +217,157 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_m_k_grid_desc.GetLength(I1); - const auto M = a_m_k_grid_desc.GetLength(I0); - const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_k_n.mDesc.GetLengths()[1]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_m_k_grid_desc, - make_tuple(make_pass_through_transform(M), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_k_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -147,9 +387,9 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; for(index_t i = 0; i < 5; ++i) { @@ -194,13 +434,17 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp new file mode 100644 index 0000000000..58ac3880d6 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp @@ -0,0 +1,291 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_kn_nm(const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp index c68442d127..e99d570413 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template -void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, - const BDesc& b_n_k_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor& a_m_k, +template +void device_gemm_xdlops_mk_nk_mn(const Tensor& a_m_k, const Tensor& b_n_k, Tensor& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -34,6 +23,34 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, c_m_n_device_buf.ToDevice(c_m_n.mData.data()); #if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 constexpr index_t BlockSize = 256; @@ -60,9 +77,93 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 256; @@ -90,7 +191,7 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 128; @@ -117,8 +218,36 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 128; @@ -144,46 +273,131 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_m_k_grid_desc.GetLength(I1); - const auto M = a_m_k_grid_desc.GetLength(I0); - const auto N = b_n_k_grid_desc.GetLength(I0); + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_n_k.mDesc.GetLengths()[0]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; +#if 1 + // non-padded GEMM const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_m_k_grid_desc, - make_tuple(make_pass_through_transform(M), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_n_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -203,9 +417,80 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; +#else + // padded GEMM + const auto a_k0_m_k1_grid_desc_tmp = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp, + make_tuple(make_pass_through_transform(K0), + make_right_pad_transform(M, MRightPad), + make_pass_through_transform(K1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc_tmp, + make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; +#endif for(index_t i = 0; i < 5; ++i) { @@ -250,13 +535,17 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), static_cast(b_n_k_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp new file mode 100644 index 0000000000..a12cf0733a --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp @@ -0,0 +1,347 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_nk_nm(const Tensor& a_m_k, + const Tensor& b_n_k, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index 2bf8adba84..91ea24f947 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R3 -#define DRIVER_GEMM_XDLOPS_V2R3 +#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP +#define DRIVER_GEMM_XDLOPS_V2R3_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -46,13 +46,17 @@ template + bool CAccessOrderMRepeatNRepeat, + bool ABlockLdsAddExtraM, + bool BBlockLdsAddExtraN> __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, const FloatAB* p_b_grid, FloatC* p_c_grid, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, + ck::index_t M01, + ck::index_t N01, AGridStepHacks, BGridStepHacks, CGridStepHacks, @@ -108,7 +112,9 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, CGridStepHacks, AGridMoveSliceWindowStepHacks, BGridMoveSliceWindowStepHacks, - CAccessOrderMRepeatNRepeat>; + CAccessOrderMRepeatNRepeat, + ABlockLdsAddExtraM, + BBlockLdsAddExtraN>; { std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " @@ -123,7 +129,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; } - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + if(!GridwiseGemm::CheckValidity( + a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); @@ -134,7 +141,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + const auto c_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 4e93ada859..366b5dffbc 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -14,15 +15,16 @@ #include "device_tensor.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp" #define USE_MODE 1 -#define USE_CONV_BWD_V4R1_XDL_NHWC 1 +#define USE_CONV_BWD_V4R1_XDL_NHWC 0 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 enum ConvBackwardDataAlgo { - V4R1XDLNHWC, - V4R1R2XDLNHWC, + V4R1XDLNHWC, // 0 + V4R1R2XDLNHWC, // 1 }; int main(int argc, char* argv[]) @@ -280,20 +282,43 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); + if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 && + in_right_pad_w == 0) + { + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1< + in_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } + else + { +#if 1 + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); +#endif + } } #endif diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 34d7247f3c..48eba2b372 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -24,7 +25,7 @@ #define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp index 13c73abf30..310dbfe1eb 100644 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -111,7 +112,7 @@ int main(int argc, char* argv[]) constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif -#if 1 +#if 0 using in_data_t = float; using acc_data_t = float; using out_data_t = float; diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp index 42c69ff6a2..e60b4905ae 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -16,11 +17,19 @@ #include "device_gemm_xdlops_mk_nk_mn.hpp" #include "device_gemm_xdlops_km_kn_mn.hpp" #include "device_gemm_xdlops_km_nk_mn.hpp" +#include "device_gemm_xdlops_mk_kn_nm.hpp" +#include "device_gemm_xdlops_mk_nk_nm.hpp" +#include "device_gemm_xdlops_km_kn_nm.hpp" +#include "device_gemm_xdlops_km_nk_nm.hpp" #define USE_GEMM_XDL_MK_KN_MN 1 #define USE_GEMM_XDL_MK_NK_MN 1 #define USE_GEMM_XDL_KM_KN_MN 1 #define USE_GEMM_XDL_KM_NK_MN 1 +#define USE_GEMM_XDL_MK_KN_NM 0 +#define USE_GEMM_XDL_MK_NK_NM 0 +#define USE_GEMM_XDL_KM_KN_NM 0 +#define USE_GEMM_XDL_KM_NK_NM 0 enum GemmAlgo { @@ -28,21 +37,21 @@ enum GemmAlgo Xdl_MK_NK_MN, // 1 Xdl_KM_KN_MN, // 2 Xdl_KM_NK_MN, // 3 + Xdl_MK_KN_NM, // 4 + Xdl_MK_NK_NM, // 5 + Xdl_KM_KN_NM, // 6 + Xdl_KM_NK_NM, // 7 }; int main(int argc, char* argv[]) { using namespace ck; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - // dynamic mode - if(argc != 10) + if(argc != 12) { printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: M, N, K\n"); + printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n"); exit(1); } @@ -57,6 +66,9 @@ int main(int argc, char* argv[]) const index_t N = std::stoi(argv[8]); const index_t K = std::stoi(argv[9]); + debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]); + debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]); + #if 0 using ab_data_t = float; using acc_data_t = float; @@ -74,69 +86,44 @@ int main(int argc, char* argv[]) std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); - if(layout == GemmMatrixLayout::MK_KN_MN) + // A + if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN || + layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM) { a_lengths_host[0] = static_cast(M); a_lengths_host[1] = static_cast(K); a_strides_host[0] = static_cast(K); a_strides_host[1] = static_cast(1); - - b_lengths_host[0] = static_cast(K); - b_lengths_host[1] = static_cast(N); - b_strides_host[0] = static_cast(N); - b_strides_host[1] = static_cast(1); - - c_lengths_host[0] = static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); } - else if(layout == GemmMatrixLayout::MK_NK_MN) - { - a_lengths_host[0] = static_cast(M); - a_lengths_host[1] = static_cast(K); - a_strides_host[0] = static_cast(K); - a_strides_host[1] = static_cast(1); - - b_lengths_host[0] = static_cast(N); - b_lengths_host[1] = static_cast(K); - b_strides_host[0] = static_cast(K); - b_strides_host[1] = static_cast(1); - - c_lengths_host[0] = static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); - } - else if(layout == GemmMatrixLayout::KM_KN_MN) + else { a_lengths_host[0] = static_cast(K); a_lengths_host[1] = static_cast(M); a_strides_host[0] = static_cast(M); a_strides_host[1] = static_cast(1); - - b_lengths_host[0] = static_cast(K); - b_lengths_host[1] = static_cast(N); - b_strides_host[0] = static_cast(N); - b_strides_host[1] = static_cast(1); - - c_lengths_host[0] = static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); } - else if(layout == GemmMatrixLayout::KM_NK_MN) - { - a_lengths_host[0] = static_cast(K); - a_lengths_host[1] = static_cast(M); - a_strides_host[0] = static_cast(M); - a_strides_host[1] = static_cast(1); + // B + if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN || + layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM) + { b_lengths_host[0] = static_cast(N); b_lengths_host[1] = static_cast(K); b_strides_host[0] = static_cast(K); b_strides_host[1] = static_cast(1); + } + else + { + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + } + // C + if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN || + layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN) + { c_lengths_host[0] = static_cast(M); c_lengths_host[1] = static_cast(N); c_strides_host[0] = static_cast(N); @@ -144,7 +131,10 @@ int main(int argc, char* argv[]) } else { - std::runtime_error("wrong! not implemented"); + c_lengths_host[0] = static_cast(N); + c_lengths_host[1] = static_cast(M); + c_strides_host[0] = static_cast(M); + c_strides_host[1] = static_cast(1); } Tensor a(a_lengths_host, a_strides_host); @@ -185,38 +175,6 @@ int main(int argc, char* argv[]) b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } - auto f_make_for_device_mk_kn_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - - auto f_make_for_device_mk_nk_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - - auto f_make_for_device_km_kn_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - - auto f_make_for_device_km_nk_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - #if USE_GEMM_XDL_MK_KN_MN if(algo == GemmAlgo::Xdl_MK_KN_MN) { @@ -225,10 +183,7 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_mk_kn_mn(); - - device_gemm_xdlops_mk_kn_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + device_gemm_xdlops_mk_kn_mn(a, b, c_device, nrepeat); } #endif @@ -240,10 +195,7 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_mk_nk_mn(); - - device_gemm_xdlops_mk_nk_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + device_gemm_xdlops_mk_nk_mn(a, b, c_device, nrepeat); } #endif @@ -255,10 +207,7 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_km_kn_mn(); - - device_gemm_xdlops_km_kn_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + device_gemm_xdlops_km_kn_mn(a, b, c_device, nrepeat); } #endif @@ -270,10 +219,55 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_km_nk_mn(); + device_gemm_xdlops_km_nk_mn(a, b, c_device, nrepeat); + } +#endif - device_gemm_xdlops_km_nk_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); +#if USE_GEMM_XDL_MK_KN_NM + if(algo == GemmAlgo::Xdl_MK_KN_NM) + { + if(layout != GemmMatrixLayout::MK_KN_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_mk_kn_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_NK_NM + if(algo == GemmAlgo::Xdl_MK_NK_NM) + { + if(layout != GemmMatrixLayout::MK_NK_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_mk_nk_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_KN_NM + if(algo == GemmAlgo::Xdl_KM_KN_NM) + { + if(layout != GemmMatrixLayout::KM_KN_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_kn_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_NK_NM + if(algo == GemmAlgo::Xdl_KM_NK_NM) + { + if(layout != GemmMatrixLayout::KM_NK_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_nk_nm(a, b, c_device, nrepeat); } #endif diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp index e2cba94100..9b66f24f7a 100644 --- a/host/host_tensor/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -2,6 +2,8 @@ #define DEVICE_HPP #include +#include +#include #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -74,6 +76,8 @@ float launch_and_time_kernel( timer.End(); + // std::this_thread::sleep_for (std::chrono::microseconds(10)); + return timer.GetElapsedTime() / nrepeat; } diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp index f0f35a78b9..f6c0d6f930 100644 --- a/host/host_tensor/include/gemm_common.hpp +++ b/host/host_tensor/include/gemm_common.hpp @@ -7,6 +7,10 @@ enum GemmMatrixLayout MK_NK_MN, // 1 KM_KN_MN, // 2 KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 }; #endif diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp index 97cf245054..c582a34258 100644 --- a/host/host_tensor/include/host_gemm.hpp +++ b/host/host_tensor/include/host_gemm.hpp @@ -80,6 +80,78 @@ void host_gemm(const Tensor& a, make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( std::thread::hardware_concurrency()); } + else if(layout == GemmMatrixLayout::MK_KN_NM) + { + auto f_mk_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_NM) + { + auto f_mk_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_NM) + { + auto f_km_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_NM) + { + auto f_km_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } else { throw std::runtime_error("wrong! not supported layout"); diff --git a/script/docker-rocm4.3.1.sh b/script/docker-rocm4.3.1.sh new file mode 100755 index 0000000000..48cb675b69 --- /dev/null +++ b/script/docker-rocm4.3.1.sh @@ -0,0 +1,14 @@ +WORKSPACE=$1 +echo "workspace: " $WORKSPACE + +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v $WORKSPACE:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash + +#--network host \ diff --git a/script/run.sh b/script/run.sh index 3b383fcf3a..1ff56b2295 100755 --- a/script/run.sh +++ b/script/run.sh @@ -4,24 +4,12 @@ export ROCR_VISIBLE_DEVICE=0 export GPU_DEVICE_ORDINAL=0 -## Boost - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH - -## Compiling -#export OLC_DEBUG_HIP_VERBOSE=1 -#export OLC_DEBUG_HIP_DUMP=1 -#export OLC_DEBUG_SAVE_TEMP_DIR=1 - -#rm -rf /root/_hip_binary_kernels_/ -#rm -rf /tmp/olCompile* - -#make -j conv_fwd_driver_offline + make -j conv_fwd_driver_offline #make -j conv_bwd_driver_offline #make -j conv_wrw_driver_offline -#make -j conv_fwd_driver_online - - make -j gemm_driver_offline +#make -j gemm_driver_offline +DRIVER="./host/driver_offline/conv_fwd_driver_offline" LAYOUT=$1 ALGO=$2 VERIFY=$3 @@ -29,30 +17,121 @@ INIT=$4 LOG=$5 REPEAT=$6 -################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#M01=$7 +#N01=$8 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + KBATCH=$7 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 -#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -################################################ layout algo verify init log repeat M___ N___ K___ -#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 - ./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 + +######### layout algo verify init log repeat M___ N___ K___ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 + +# Resnet50 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + +# 256x128x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH + + + +# 128x128x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH + + +# 128x64x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 + +# 64x128x32 c64 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH + +# 64x64x32 c32 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448