diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..949f044b7d --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto 
ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + 
make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp b/composable_kernel/include/tensor_description/multi_index_transform.hpp index 42a5a875b7..1a25e99f3b 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division } }; +// Implementation of "Merge" transformation primitive that uses division and mod. 
It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v3_division_mod +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v3_division_mod() = default; + + __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + const index_t tmp2 = idx_low[i]; + idx_low(i) = tmp / this->low_lengths_scan_[i]; + idx_diff_low(i) = idx_low[i] - tmp2; + tmp %= this->low_lengths_scan_[i]; + }); + + const index_t tmp2 = idx_low[INm1]; + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_low[INm1] - tmp2; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_direct_division_mod, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + template struct UnMerge { diff --git 
a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp index 6d4e01888b..32acceb608 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp @@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng template __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) { -#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION +#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION + return make_merge_transform_v2_magic_division(low_lengths); +#else + return make_merge_transform_v1_carry_check(low_lengths); +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v1_carry_check(const LowLengths& low_lengths) +{ return Merge_v1_carry_check{low_lengths}; -#else -#if 1 - return Merge_v2_magic_division{low_lengths}; -#else - return Merge_v2r2_magic_division{low_lengths}; -#endif -#endif } template __host__ __device__ constexpr auto make_merge_transform_v2_magic_division(const LowLengths& low_lengths) { +#if 1 return Merge_v2_magic_division{low_lengths}; +#else + return Merge_v2r2_magic_division{low_lengths}; +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v3_division_mod(const LowLengths& low_lengths) +{ + return Merge_v3_division_mod{low_lengths}; } template diff --git a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp index 3b647e433a..50a8088bba 100644 --- a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -189,8 +189,7 @@ struct TensorAdaptor bool is_known = true; static_for<0, Transforms::Size(), 1>{}([&](auto i) { - is_known &= - remove_cv_t>::IsKnownAtCompileTime(); + 
is_known &= remove_cvref_t::IsKnownAtCompileTime(); }); return is_known && is_known_at_compile_time::value; diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index a6a57ba63b..8f6a5a3e43 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -185,8 +185,7 @@ struct TensorDescriptor bool is_known = true; static_for<0, Transforms::Size(), 1>{}([&](auto i) { - is_known &= - remove_cv_t>::IsKnownAtCompileTime(); + is_known &= remove_cvref_t::IsKnownAtCompileTime(); }); return is_known && is_known_at_compile_time::value && @@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& template using TensorCoordinate_t = decltype(make_tensor_coordinate( - TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); template using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( - TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp index 03f889649e..5cc2f2393e 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp @@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 const BThreadBuffer& b_thread_buf, CThreadBuffer& c_thread_buf) const { - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! 
inconsistent type"); constexpr auto I0 = Number<0>{}; diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index ee6a0b7427..a8236737df 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -4,21 +4,21 @@ #include "common_header.hpp" #include "threadwise_tensor_slice_transfer.hpp" #include "xdlops_gemm.hpp" +#include "tensor_adaptor.hpp" namespace ck { template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { - - using CIndex = MultiIndex<2>; - static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -26,111 +26,165 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 static constexpr index_t WaveSize = 64; - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t KPerBlock = K0; - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); - __device__ constexpr auto GetCLayout() const 
{ return xdlops_gemm.GetCLayout(); } + constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); - __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } - - __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } __device__ static auto CalculateAThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0); } __device__ static auto CalculateBThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - 
const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0); } template - __device__ static CIndex + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { + const auto wave_idx = GetWaveIdx(); - const index_t waveId = get_thread_local_1d_id() / WaveSize; + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - return CIndex{m_offset, n_offset}; + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); } - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() - : a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} + __host__ __device__ 
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); + static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), + "wrong! K0 dimension not consistent"); - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2), "wrong! K1 dimension not consistent"); static_assert(BlockSize == MWaves * NWaves * WaveSize, "BlockSize != MWaves * NWaves * WaveSize\n"); - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); } + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() + { + constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return 
xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc); + } + + template + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); + } + + __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor(); + static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor(); + template __device__ void Run(const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, @@ -141,49 +195,48 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + 
vector_type a_thread_vec; - vector_type a_thread_vec; + vector_type b_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { + static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) { // read A - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), + a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, + make_tuple(k0, I0, I0, I0, I0), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0), a_thread_buf); // read B - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), + b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, + make_tuple(k0, I0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0), b_thread_buf); - using mfma_input_type = - typename vector_type::type; - - static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - a_thread_vec.template AsType()(Number{}) = a_thread_buf[Number{}]; - }); - - static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - b_thread_vec.template AsType()(Number{}) = b_thread_buf[Number{}]; - }); + using mfma_input_type = typename vector_type::type; static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf); + static_for<0, K1, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + }); + + static_for<0, K1, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf); }); }); }); @@ -191,333 +244,38 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 private: // A[K, M] - static constexpr auto a_thread_desc_ = - 
make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, I1, Number{})); // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, I1, Number{})); - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, MRepeat, 1, 1, K1>, + Sequence<0, 1, 2, 3, 4>, + 4, K1, 1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, NRepeat, 1, 1, K1>, + Sequence<0, 1, 2, 3, 4>, + 4, K1, 1>; - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline -{ - - using CIndex = MultiIndex<2>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto xdlops_gemm = XdlopsGemm{}; - - static constexpr index_t WaveSize = 64; - - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); - - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; - - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; - - __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } - - __device__ constexpr auto 
GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } - - __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } - } - - template - __device__ static CIndex - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) - { - - const index_t waveId = get_thread_local_1d_id() / WaveSize; - - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; - - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; - - return CIndex{m_offset, 
n_offset}; - } - - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() - : a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} - { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), - "wrong! K1 dimension not consistent"); - - static_assert(BlockSize == MWaves * NWaves * WaveSize, - "BlockSize != MWaves * NWaves * WaveSize\n"); - - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - 
- // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - static_for{}([&](auto k) { - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - }); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - } - - private: - // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - static constexpr auto c_thread_desc_ = - 
make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 207f73072f..3e4d74e9d8 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -18,7 +18,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -29,7 +29,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const AK0MK1GridDesc a_k0_m_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc, const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -43,7 +43,7 @@ __global__ void p_shared_block, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER @@ -52,7 +52,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -63,7 +63,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const void CONSTANT* p_c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -73,8 +73,9 @@ __global__ void 
cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); const auto b_k0_n_k1_grid_desc = *reinterpret_cast( cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); - const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc)); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); const auto c_block_cluster_adaptor = *reinterpret_cast( cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); @@ -86,7 +87,7 @@ __global__ void p_shared_block, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); } #endif @@ -102,8 +103,8 @@ template {}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -179,14 +183,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto N = b_k0_n_k1_grid_desc.GetLength(I1); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) && K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0) && - (MPerBlock % MPerWave == 0 && NPerBlock % NPerWave == 0); + (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0); } __host__ __device__ static 
constexpr index_t @@ -201,29 +207,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } __host__ __device__ static constexpr auto - MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) { - constexpr auto xdlops_gemm = XdlopsGemm{}; + constexpr auto max_lds_align = K1; - constexpr auto CLayout = xdlops_gemm.GetCLayout(); + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); - constexpr auto M0 = Number{}; - constexpr auto M1 = Number{}; - constexpr auto M2 = Number{}; + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); - constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - constexpr auto N1 = Number{}; - - const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), - make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); - - return c_m0_m1_m2_n_grid_desc; + return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); } __host__ __device__ static constexpr auto @@ -253,8 +258,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return c_blockid_to_m0_n0_block_cluster_adaptor; } - using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); __device__ static void Run(const FloatAB* 
__restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -262,7 +267,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAB* __restrict__ p_shared_block, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor& c_block_cluster_adaptor) { const auto a_grid_buf = make_dynamic_buffer( @@ -270,7 +275,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); @@ -358,50 +363,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // register // sanity check - static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && - NPerBlock % (NPerWave * NRepeat) == 0, - "wrong!"); - - constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - const auto blockwise_gemm = - BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; - - constexpr auto CLayout = blockwise_gemm.GetCLayout(); - - constexpr index_t BlkSize = CLayout.GetBlkSize(); - 
constexpr index_t NumBlks = CLayout.GetNumBlks(); - constexpr index_t NumXdlops = CLayout.GetNumXdlops(); - - static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; constexpr auto c_mr_nr_blk_desc = make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor(); + constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize(); + StaticBuffer, + vector_type, c_mr_nr_blk_desc.GetElementSpaceSize(), true> c_thread_buf; @@ -474,41 +455,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } -#if 0 // output: register to global memory { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); - constexpr index_t N0 = CLayout.N1(); - constexpr index_t N1 = CLayout.N0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number<1>{}, - Number<1>{}, - Number{}, - Number<1>{}, - Number{}, - Number<1>{})); - - StaticBuffer - c_blk_buf_; - - static_for<0, MRepeat, 1>{}([&](auto mr_i) { - static_for<0, NRepeat, 1>{}([&](auto nr_i) { - constexpr auto blk_off = - c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i)); - - static_for<0, BlkSize, 1>{}([&](auto j) { - c_blk_buf_(Number{}) = - c_thread_buf[Number{}] - .template AsType()[Number{}]; - }); - }); - }); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix 
starting index @@ -521,145 +475,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; - - constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(c_m0_m1_m2_n_thread_desc), - decltype(c_m0_m1_m2_n_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_m0_m1_m2_n_grid_desc, - make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves), - n_thread_data_on_grid / (N1 * NWaves), - m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0), - n_thread_data_on_grid % (N1 * NWaves) / N1, - m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, - n_thread_data_on_grid % N1)} - .Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_blk_buf_, - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - } -#else - { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - constexpr auto 
c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), + Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, true>{ - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, make_multi_index(0, 0, 0, 0, - m_thread_data_on_grid / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, + m_thread_data_on_grid / (M3 * M4), + m_thread_data_on_grid % (M3 * M4) / M4, + m_thread_data_on_grid % M4, n_thread_data_on_grid)}; auto init_copy = [&](auto c_thread_idx_) { constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); return c_thread_idx_; }; auto mrepeat_plus_copy = [&](auto c_thread_idx_) { constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + mrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + 
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto nrepeat_plus_copy = [&](auto c_thread_idx_) { constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + nrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto mrepeat_minus_copy = [&](auto c_thread_idx_) { constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + mrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto nrepeat_minus_copy = [&](auto c_thread_idx_) { constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + nrepeat_step_minus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + 
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or @@ -791,7 +696,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 init_copy(make_tuple(I0, I0)); } } -#endif } }; // namespace ck diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp index a925a5cd68..8b75381026 100644 --- a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp @@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 CBuffer& c_buf, COriginIdx) { - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! 
inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ CBuffer& c_buf, COriginIdx) { - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp index 015ad675fb..f6c15fd85a 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp @@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 CDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp index 0c7aa978a7..20e9a5b366 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp @@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1 static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert(is_known_at_compile_time>>::value, + static_assert(is_known_at_compile_time>::value, "wrong! OriginIdx need to be known at compile-time"); // Desc is known at compile-time - constexpr auto desc = remove_cv_t>{}; + constexpr auto desc = remove_cvref_t{}; // OriginIdx is known at compile-time constexpr auto origin_idx = to_multi_index(OriginIdx{}); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 27fd91812d..157828bf0f 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3 static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); - static_assert( - is_known_at_compile_time>>::value, - "wrong! 
SrcSliceOrigin need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); - // static_assert(is_same>, - // remove_cv_t>>::value, - //"wrong! SrcBuffer data type is wrong"); - // SrcDesc and src_slice_origin_idx are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); constexpr auto I0 = Number<0>{}; @@ -208,10 +203,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3 coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); // copy data from dst_vector into dst_buf - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); + if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) + { + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) + { + dst_buf.template AtomicAdd( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } constexpr auto move_on_dim = [&]() constexpr { @@ -411,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2 static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! DstDesc need to known at compile-time"); - static_assert( - is_known_at_compile_time>>::value, - "wrong! DstSliceOrigin need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + "wrong! 
inconsistent type"); // DstDesc and dst_slice_origin_idx are known at compile-time - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cvref_t{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; constexpr auto I0 = Number<0>{}; @@ -729,9 +733,9 @@ struct ThreadwiseTensorSliceTransfer_v3 SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -886,9 +890,9 @@ struct ThreadwiseTensorSliceTransfer_v3 DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -1303,24 +1307,21 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert( - is_known_at_compile_time< - remove_cv_t>>::value && - is_known_at_compile_time>>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp index ccac4b7b44..bbdaa5fa2b 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp @@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); // tensor descriptor for src_vector constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; @@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); // tensor descriptor for dst_vector constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; @@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! 
SrcDesc and DstDesc need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert( - is_known_at_compile_time< - remove_cv_t>>::value && - is_known_at_compile_time>>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index affe096ace..f945b0fdf5 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -7,21 +7,18 @@ namespace ck { -enum struct mfma_instr +enum struct MfmaInstr { - /// fp32 mfma_f32_32x32x1xf32 = 0, mfma_f32_16x16x1xf32, mfma_f32_4x4x1xf32, mfma_f32_32x32x2xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction - /// fp16 mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, mfma_f32_32x32x8f16, // k reduction mfma_f32_16x16x16f16, // k reduction - /// bfp16 mfma_f32_32x32x2bf16, mfma_f32_16x16x2bf16, mfma_f32_4x4x2bf16, @@ -29,25 +26,23 @@ 
enum struct mfma_instr mfma_f32_16x16x8bf16, // k reduction }; -template -struct mfma_info; +template +struct mfma_type; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 1; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t 
k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; 
- static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 1; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template // treat 4x4x1 as a single-blk 4x64 mfma template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 1; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static 
constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 8; - static constexpr index_t cycles = 64; - static 
constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 16; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t 
num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 4; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t 
n_per_blk = 64; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template #if 0 template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t 
k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 8; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; 
- static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 2; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 2; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + 
static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; #endif -template -struct xdlops_info +template +struct MfmaSelector { - static constexpr auto mfma_type = mfma_info{}; + template + static constexpr auto GetMfma(); - static constexpr index_t MPerXdlops = MPerXdlops_; - static constexpr index_t NPerXdlops = NPerXdlops_; + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x2xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x8f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x16f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + +#if 0 + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return 
xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } +#endif + + static constexpr auto selected_mfma = mfma_type()>{}; + + __host__ __device__ static constexpr void mfma_check() + { + static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == + selected_mfma.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks == + selected_mfma.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks || + selected_mfma.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size == + selected_mfma.m_per_blk * selected_mfma.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_mfma.is_k_reduction || + (selected_mfma.num_input_blks == selected_mfma.num_output_blks), + "is_k_reduction wrong!"); + } + + __host__ __device__ constexpr MfmaSelector() { mfma_check(); } static constexpr bool IsABroadcast() { @@ -505,186 +652,33 @@ struct xdlops_info return true; } - static constexpr bool IsKReduction() - { - return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); - } - static 
constexpr index_t GetKPerXdlops() { - return IsKReduction() ? mfma_type.num_input_blks : 1; + return (selected_mfma.is_k_reduction ? selected_mfma.num_input_blks : 1) * + selected_mfma.k_per_blk; } - static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } + static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } }; -template +template struct XdlopsGemm { - template - static constexpr auto GetXdlopsInfo(); - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - -#if 0 - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static 
constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } -#endif + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; using CIndex = MultiIndex<2>; - __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } __device__ static constexpr index_t GetNumXdlops() { - return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + return MPerXdlops * NPerXdlops / + (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks); } __host__ __device__ constexpr XdlopsGemm() @@ -697,104 +691,142 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); - static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, - "m != num_input_blks * num_regs_blk"); - static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || - mfma_type.num_output_blks == 1, - "incorrect num_output_blks"); - static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * 
mfma_type.n, - "num_regs_blk incorrect"); + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } - static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); + template + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc) + { + const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0); + const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1); + const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2); + const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3); + + return transform_tensor_descriptor( + c_m0_n0_m1_n1_m2_n2_desc, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); } __device__ static constexpr index_t GetRegSizePerXdlops() { - return MPerXdlops * NPerXdlops / mfma_type.wave_size; + return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } - template + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || is_same::value, "base base_type must be float, half, ushort!"); - static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); - - constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); - - static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { - constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); - constexpr index_t b_offset = 
BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); - - mfma_type.template run( - p_a_wave[Number{}], - p_b_wave[Number{}], - p_c_thread); + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + mfma_instr.template run( + p_a_wave[k], p_b_wave[k], p_c_thread); }); } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) { - const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; - const index_t blk_id = laneId / mfma_type.num_threads_blk; - const index_t blk_td = laneId % mfma_type.num_threads_blk; + const auto blk_idx = GetBlkIdx(); - index_t n_offset = blk_i * mfma_type.n 
+ blk_td; - index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size; return CIndex{m_offset, n_offset}; } - static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; - static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; - static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; - static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; + static constexpr auto mfma = MfmaSelector{}; - static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); - static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); - static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); + static constexpr auto mfma_instr = mfma.selected_mfma; - static constexpr auto GetBlkId(const index_t lane_id) + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto KPerThread = mfma.GetKPerThread(); + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() { - return lane_id / mfma_type.num_threads_blk; + return make_tuple( + Number{}, I1, Number{}, I1); } - - static constexpr auto GetBlkTd(const index_t lane_id) - { - return lane_id % mfma_type.num_threads_blk; - } - - static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; - - struct CLayout - { - __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } - __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } - __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } - __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } - - __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } - - __device__ static constexpr index_t GetNumBlks() { return 
mfma_type.num_output_blks; } - - __device__ static constexpr index_t GetNumXdlops() - { - return MPerXdlops * NPerXdlops / - (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); - } - }; - - __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } }; } // namespace ck diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index a54607a053..3df53bda44 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -202,6 +202,22 @@ llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); +// atomic add +// int +__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// float +__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); template __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, @@ -624,8 +640,130 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } } +template +__device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! 
not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if 
constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + // buffer_load requires: -// 1) p_src_wave must be in global memory space +// 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template @@ -659,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, } // buffer_load requires: -// 1) p_src_wave must be in global memory space +// 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template @@ -687,8 +825,8 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, } // buffer_store requires: -// 1) p_dst_wave must be global memory -// 2) p_dst_wave to be a wavewise pointer. +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. 
template __device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, @@ -720,5 +858,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #endif } +// buffer_atomic_add requires: +// 1) p_dst_wave must point to global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 0 : 0x7fffffff; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/array.hpp b/composable_kernel/include/utility/array.hpp index 7271094d39..911cefd057 100644 --- a/composable_kernel/include/utility/array.hpp +++ b/composable_kernel/include/utility/array.hpp @@ -48,7 +48,7 @@ struct Array template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... 
xs) { - using data_type = remove_cv_t>; + using data_type = remove_cvref_t; return Array{{std::forward(x), std::forward(xs)...}}; } diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 521ad24d47..c229162d9b 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -85,8 +85,8 @@ #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #endif -#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK -#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 +#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 #endif // pass tensor descriptor by value or void* diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index 210c493602..886737efac 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -43,18 +43,15 @@ struct DynamicBuffer __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; } template >>::type, - typename scalar_type>>::type>::value, - bool>::type = false> + typename enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { // X contains multiple T - constexpr index_t scalar_per_t_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); @@ -71,15 +68,14 @@ struct DynamicBuffer if constexpr(InvalidElementUseNumericalZeroValue) { - return amd_buffer_load_invalid_element_return_return_zero< - remove_cv_t>, - t_per_x>(p_data_, i, is_valid_element, element_space_size_); + return amd_buffer_load_invalid_element_return_return_zero, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_); } else { - return amd_buffer_load_invalid_element_return_customized_value< - remove_cv_t>, - t_per_x>( + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } @@ -98,18 +94,15 @@ struct DynamicBuffer } template >>::type, - typename scalar_type>>::type>::value, - bool>::type = false> + typename enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); @@ -119,7 +112,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_ADDRESSING constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - amd_buffer_store>, t_per_x>( + amd_buffer_store, t_per_x>( x, p_data_, i, is_valid_element, element_space_size_); #else if(is_valid_element) @@ -140,70 +133,65 @@ struct DynamicBuffer // ISA, so I try to let compiler emit IR "store" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix - if constexpr(is_same>>::type, - int8_t>::value) + if constexpr(is_same>::type, int8_t>::value) { - static_assert( - (is_same>, int8_t>::value && - is_same>, int8_t>::value) || - (is_same>, int8_t>::value && - is_same>, int8x2_t>::value) || - (is_same>, int8_t>::value && - is_same>, int8x4_t>::value) || - (is_same>, int8x4_t>::value && - is_same>, int8x4_t>::value) || - (is_same>, int8x8_t>::value && - is_same>, int8x8_t>::value) || - (is_same>, int8x16_t>::value && - is_same>, int8x16_t>::value), - "wrong! not implemented for this combination, please add " - "implementation"); + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! 
not implemented for this combination, please add " + "implementation"); - if constexpr(is_same>, int8_t>::value && - is_same>, int8_t>::value) + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, int8_t>::value && - is_same>, int8x2_t>::value) + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, int8_t>::value && - is_same>, int8x4_t>::value) + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x4_t>::value && - is_same>, int8x4_t>::value) + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x8_t>::value && - is_same>, int8x8_t>::value) + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x16_t>::value && - is_same>, int8x16_t>::value) + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix @@ -227,6 +215,35 @@ struct DynamicBuffer } } + template >::type, + typename scalar_type>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& 
x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; + + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X need to be multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ADDRESSING + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); +#else + if(is_valid_element) + { + atomicAdd(&p_data_[i], x); + } +#endif + } + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } diff --git a/composable_kernel/include/utility/magic_division.hpp b/composable_kernel/include/utility/magic_division.hpp index b7489016e9..612aceea2a 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/composable_kernel/include/utility/magic_division.hpp @@ -114,12 +114,11 @@ struct MagicDivision __host__ __device__ static constexpr uint32_t DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift) { - uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32; + uint32_t tmp = __umulhi(dividend, multiplier); return (tmp + dividend) >> shift; } -#if 1 // debug - // HACK: magic division for int32_t + // magic division for int32_t // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 need to be // non-negative for result to be correct // TODO: figure out how to do magic number divison for int32_t as dividended @@ -127,27 +126,9 @@ struct MagicDivision DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) { uint32_t dividend_u32 = as_type(dividend_i32); - uint32_t tmp = - (static_cast(dividend_u32) * static_cast(multiplier)) >> 32; + uint32_t tmp = __umulhi(dividend_u32, multiplier); return (tmp + 
dividend_u32) >> shift; } -#else - // the inline ASM is producing wrong result - __host__ __device__ static int32_t - DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift) - { - uint32_t r; - asm volatile("\n \ - v_mul_hi_u32 %0, %1, %2 \n \ - v_add_u32_e32 %0, %1, %0 \n \ - v_lshrrev_b32_e32 %0, %3, %0 \n \ - " - : "=v"(r) - : "v"(as_type(dividend_i32)), "s"(multiplier), "s"(shift)); - - return as_type(r); - } -#endif }; } // namespace ck diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index ee96a8b435..70f4d77d87 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl __host__ __device__ constexpr auto make_tuple(Xs&&... xs) { - return Tuple>...>(std::forward(xs)...); + return Tuple...>(std::forward(xs)...); } } // namespace ck diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/composable_kernel/include/utility/tuple_helper.hpp index 9499a3596c..55a79d2594 100644 --- a/composable_kernel/include/utility/tuple_helper.hpp +++ b/composable_kernel/include/utility/tuple_helper.hpp @@ -14,9 +14,7 @@ struct is_known_at_compile_time> return container_reduce( Tuple{}, [](auto x, bool r) { - return is_known_at_compile_time< - remove_cv_t>>::value & - r; + return is_known_at_compile_time>::value & r; }, true); } diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index c1208ac3cb..71239e0ecc 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -374,13 +374,8 @@ extern "C" __global__ void CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, 
CGridBlockCluster_BlockId_To_GM10_GN10{})); - const auto desc_tuple = *reinterpret_cast( -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" - // TODO: how to cast? - (const void*)p_desc_tuple -#pragma clang diagnostic pop - ); + const auto desc_tuple = + *reinterpret_cast(cast_pointer_to_generic_address_space(p_desc_tuple)); const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index fec11e99af..a3b3613293 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -13,9 +13,15 @@ include_directories(BEFORE set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) +set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) +set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) +add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE}) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) +target_link_libraries(gemm_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 7bd82bf6d5..8f49473563 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ 
b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,9 +56,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -84,9 +84,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -112,9 +112,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -140,9 +140,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 4; @@ -168,9 +168,9 @@ void 
device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -208,40 +208,42 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 - 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 
// 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + //clang-format on constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; @@ -263,8 +265,8 @@ void 
device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -289,7 +291,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat @@ -301,7 +303,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( in_gemmm_gemmn_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 0ebf8571f4..2cbae2daf3 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -195,25 +195,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto 
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + //clang-format on constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; @@ -265,7 +267,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk GemmCThreadTransferDstScalarPerVector, decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), true // CAccessOrderMRepeatNRepeat @@ -277,7 +279,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk in_gemmm_gemmn_grid_desc, out_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, 
out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..e97bc9c1c7 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -0,0 +1,228 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = 
make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 1, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 
2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false>(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 4a9d01081c..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,280 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include 
"driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 1; - constexpr 
index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr 
index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] - constexpr index_t 
BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#endif - - const auto descs = -#if 1 - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad -#else - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 -#endif - ( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - for(index_t i = 0; i < 5; ++i) - { -#if 0 - float ave_time = launch_kernel_gemm_xdlops_v1 -#else - float ave_time = launch_kernel_gemm_xdlops_v2 -#endif - , - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<1, 0, 2>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_KPack, - false, // don't move back src 
coordinate after threadwise copy, which will be fused - // with MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(descs[I4]), - decltype(descs[I5]), - decltype(descs[I6]), - decltype(descs[I7]), - decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - descs[I0], - descs[I1], - descs[I2], - descs[I3], - descs[I4], - descs[I5], - descs[I6], - descs[I7], - descs[I8], - nrepeat); - - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 695ffeeb36..d65ecadb4d 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -47,7 +47,35 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); -#if 1 +#if 0 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t 
GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - 
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; @@ -169,7 +200,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), @@ -180,7 +211,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( out_gemmm_gemmn_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 141a326574..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,229 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r2.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const 
OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto 
wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1>, - 2, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp 
b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 692751bfb3..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,302 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto 
out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using 
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - 
make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - 
>(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 7067291c8a..52432664de 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,8 +56,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; + constexpr index_t 
GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 4; @@ -84,9 +84,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -112,9 +112,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 4; @@ -140,9 +140,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -168,9 +168,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 4; @@ -196,9 +196,9 @@ void 
device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -249,23 +249,23 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 
0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; @@ -287,8 +287,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -313,7 +313,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat @@ -325,7 +325,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( out_gemmm_gemmn_grid_desc, in_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp new file mode 100644 index 0000000000..d9169649e6 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" 
+#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, + const BDesc& b_k_n_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_k_m, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + 
constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k_m_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_k_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + 
Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + 
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp new file mode 100644 index 0000000000..90e258d581 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, + const BDesc& b_n_k_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_k_m, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + 
c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + 
const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_n_k_grid_desc.GetLength(I0); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k_m_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_n_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 
// 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " 
TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp new file mode 100644 index 0000000000..ab235d97e7 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, + const BDesc& b_k_n_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using 
BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_m_k_grid_desc.GetLength(I1); + const auto M = a_m_k_grid_desc.GetLength(I0); + const auto N = b_k_n_grid_desc.GetLength(I1); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_m_k_grid_desc, + make_tuple(make_pass_through_transform(M), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_k_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // 
HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + 
BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp new file mode 100644 index 0000000000..c68442d127 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp @@ -0,0 +1,275 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, + const BDesc& b_n_k_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_m_k, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using 
namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + 
constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; 
+ constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_m_k_grid_desc.GetLength(I1); + const auto M = a_m_k_grid_desc.GetLength(I0); + const auto N = b_n_k_grid_desc.GetLength(I0); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_m_k_grid_desc, + make_tuple(make_pass_through_transform(M), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_n_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + 
c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index edfce52a19..2bf8adba84 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -17,8 +17,8 @@ template , remove_reference_t, - remove_reference_t, + remove_reference_t, remove_reference_t>; #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE @@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_c_grid, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); - c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); float ave_time = 
launch_and_time_kernel( @@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_c_grid, cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); #endif return ave_time; diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 67cea94813..4e93ada859 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -41,7 +41,7 @@ int main(int argc, char* argv[]) // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -79,7 +79,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -90,28 +90,28 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = 
Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif #if 0 @@ -119,9 +119,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 32c33003c5..34d7247f3c 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -19,13 +19,13 @@ #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" 
#include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#define USE_MODE 1 -#define USE_CONV_FWD_V4R4_NCHW 1 -#define USE_CONV_FWD_V4R4R2_NHWC 1 +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_FWD_V4R4_NCHW 0 +#define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo { @@ -49,11 +49,11 @@ int main(int argc, char* argv[]) constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; -#if USE_MODE +#if USE_DYNAMIC_MODE // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -91,7 +91,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -102,38 +102,38 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t 
in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif -#if 1 +#if 0 using in_data_t = float; using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; @@ -228,7 +228,6 @@ int main(int argc, char* argv[]) } auto f_make_for_device_nchw = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); const auto wei_lengths_dev = make_tuple(K, C, Y, X); const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); @@ -236,19 +235,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto 
wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return make_tuple(in_lengths_dev, wei_lengths_dev, @@ -260,7 +246,6 @@ int main(int argc, char* argv[]) }; auto f_make_for_device_nhwc = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); @@ -268,19 +253,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return make_tuple(in_lengths_dev, wei_lengths_dev, diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp new file mode 100644 index 0000000000..13c73abf30 --- /dev/null +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" 
+#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv_bwd_weight.hpp" +#include "device_tensor.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 + +enum ConvBackwardWeightAlgo +{ + V4R4R2XDLNCHW, +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t N = std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); + + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); + + const 
index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + constexpr auto N = Number<128>{}; + constexpr auto C = Number<128>{}; + constexpr auto Hi = Number<14>{}; + constexpr auto Wi = Number<14>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; + + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; + + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; +#endif + +#if 1 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + if(layout == 
ConvTensorLayout::NCHW) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(C); + in_lengths_host[2] = static_cast(Hi); + in_lengths_host[3] = static_cast(Wi); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(C); + wei_lengths_host[2] = static_cast(Y); + wei_lengths_host[3] = static_cast(X); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(K); + out_lengths_host[2] = static_cast(Ho); + out_lengths_host[3] = static_cast(Wo); + } + else if(layout == ConvTensorLayout::NHWC) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = static_cast(Hi); + in_lengths_host[2] = static_cast(Wi); + in_lengths_host[3] = static_cast(C); + wei_lengths_host[0] = static_cast(K); + wei_lengths_host[1] = static_cast(Y); + wei_lengths_host[2] = static_cast(X); + wei_lengths_host[3] = static_cast(C); + out_lengths_host[0] = static_cast(N); + out_lengths_host[1] = static_cast(Ho); + out_lengths_host[2] = static_cast(Wo); + out_lengths_host[3] = static_cast(K); + } + else + { + std::runtime_error("wrong! 
not implemented"); + } + + Tensor in(in_lengths_host); + Tensor wei_device(wei_lengths_host); + Tensor wei_host(wei_lengths_host); + Tensor out(out_lengths_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: "); + ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: "); + ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: "); + print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w)); + print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w)); + print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w)); + print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w)); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + in.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 5: + in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread); + break; + default: + in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread); + + auto gen_out = [](auto... is) { + return GeneratorTensor_2{1, 5}(is...) 
* GeneratorTensor_Checkboard{}(is...); + }; + out.GenerateTensorValue(gen_out, num_thread); + } + + auto f_make_for_device_nchw = [&]() { + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_WRW_V4R4R2_XDL_NCHW + if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_backward_weights(out, + in, + wei_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(wei_host, wei_device); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl; + } + } +} diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp new file mode 
100644 index 0000000000..42c69ff6a2 --- /dev/null +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -0,0 +1,294 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "gemm_common.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdlops_mk_kn_mn.hpp" +#include "device_gemm_xdlops_mk_nk_mn.hpp" +#include "device_gemm_xdlops_km_kn_mn.hpp" +#include "device_gemm_xdlops_km_nk_mn.hpp" + +#define USE_GEMM_XDL_MK_KN_MN 1 +#define USE_GEMM_XDL_MK_NK_MN 1 +#define USE_GEMM_XDL_KM_KN_MN 1 +#define USE_GEMM_XDL_KM_NK_MN 1 + +enum GemmAlgo +{ + Xdl_MK_KN_MN, // 0 + Xdl_MK_NK_MN, // 1 + Xdl_KM_KN_MN, // 2 + Xdl_KM_NK_MN, // 3 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + // dynamic mode + if(argc != 10) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: M, N, K\n"); + exit(1); + } + + const auto layout = static_cast(std::stoi(argv[1])); + const auto algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t M = std::stoi(argv[7]); + const index_t N = std::stoi(argv[8]); + const index_t K = std::stoi(argv[9]); + +#if 0 + using ab_data_t = float; + using acc_data_t = float; + using c_data_t = float; +#elif 1 + using ab_data_t = half_t; + using acc_data_t = float; + using c_data_t = half_t; +#elif 1 + using ab_data_t = int8_t; + using acc_data_t = int32_t; + using c_data_t = int8_t; +#endif + + std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); + std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); + 
+ if(layout == GemmMatrixLayout::MK_KN_MN) + { + a_lengths_host[0] = static_cast(M); + a_lengths_host[1] = static_cast(K); + a_strides_host[0] = static_cast(K); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + a_lengths_host[0] = static_cast(M); + a_lengths_host[1] = static_cast(K); + a_strides_host[0] = static_cast(K); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(N); + b_lengths_host[1] = static_cast(K); + b_strides_host[0] = static_cast(K); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + a_lengths_host[0] = static_cast(K); + a_lengths_host[1] = static_cast(M); + a_strides_host[0] = static_cast(M); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + a_lengths_host[0] = static_cast(K); + a_lengths_host[1] = static_cast(M); + a_strides_host[0] = static_cast(M); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(N); + b_lengths_host[1] = static_cast(K); + b_strides_host[0] = static_cast(K); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] 
= static_cast(N); + c_strides_host[1] = static_cast(1); + } + else + { + throw std::runtime_error("wrong! not implemented"); + } + + Tensor a(a_lengths_host, a_strides_host); + Tensor b(b_lengths_host, b_strides_host); + Tensor c_host(c_lengths_host, c_strides_host); + Tensor c_device(c_lengths_host, c_strides_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: "); + ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: "); + ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: "); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + auto f_make_for_device_mk_kn_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_mk_nk_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); + const auto
c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_km_kn_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_km_nk_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + +#if USE_GEMM_XDL_MK_KN_MN + if(algo == GemmAlgo::Xdl_MK_KN_MN) + { + if(layout != GemmMatrixLayout::MK_KN_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_mk_kn_mn(); + + device_gemm_xdlops_mk_kn_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_NK_MN + if(algo == GemmAlgo::Xdl_MK_NK_MN) + { + if(layout != GemmMatrixLayout::MK_NK_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_mk_nk_mn(); + + device_gemm_xdlops_mk_nk_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_KN_MN + if(algo == GemmAlgo::Xdl_KM_KN_MN) + { + if(layout != GemmMatrixLayout::KM_KN_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_km_kn_mn(); + + device_gemm_xdlops_km_kn_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_NK_MN + if(algo == GemmAlgo::Xdl_KM_NK_MN) + { + if(layout != GemmMatrixLayout::KM_NK_MN) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto descs = f_make_for_device_km_nk_mn(); + + device_gemm_xdlops_km_nk_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + + if(do_verification) + { + host_gemm(a, b, c_host, layout); + + check_error(c_host, c_device); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; + } + } +} diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp new file mode 100644 index 0000000000..f0f35a78b9 --- /dev/null +++ b/host/host_tensor/include/gemm_common.hpp @@ -0,0 +1,12 @@ +#ifndef GEMM_COMMON_HPP +#define GEMM_COMMON_HPP + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +#endif diff --git a/host/host_tensor/include/host_conv_bwd_weight.hpp b/host/host_tensor/include/host_conv_bwd_weight.hpp new file mode 100644 index 0000000000..ed3e8c3042 --- /dev/null +++ b/host/host_tensor/include/host_conv_bwd_weight.hpp @@ -0,0 +1,89 @@ +#pragma once +#include "host_tensor.hpp" + +template +void host_direct_convolution_backward_weights( + const Tensor& out, + const Tensor& in, + Tensor& wei, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads&, + const ConvTensorLayout layout = ConvTensorLayout::NCHW) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + auto f_kcyx = [&](auto k, auto c, auto y, auto x) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo) + { + int 
wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 && + wi < in.mDesc.GetLengths()[3]) + { + v += static_cast(in(n, c, hi, wi)) * + static_cast(out(n, k, ho, wo)); + } + } + } + } + wei(k, c, y, x) = v; + }; + + auto f_kyxc = [&](auto k, auto y, auto x, auto c) { + double v = 0; + for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n) + { + for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho) + { + int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0]; + for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo) + { + int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1]; + if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 && + wi < in.mDesc.GetLengths()[2]) + { + v += static_cast(in(n, hi, wi, c)) * + static_cast(out(n, ho, wo, k)); + } + } + } + } + wei(k, y, x, c) = v; + }; + + if(layout == ConvTensorLayout::NCHW) + { + make_ParallelTensorFunctor(f_kcyx, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else if(layout == ConvTensorLayout::NHWC) + { + make_ParallelTensorFunctor(f_kyxc, + wei.mDesc.GetLengths()[0], + wei.mDesc.GetLengths()[1], + wei.mDesc.GetLengths()[2], + wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! 
not supported layout"); + } +} diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp new file mode 100644 index 0000000000..97cf245054 --- /dev/null +++ b/host/host_tensor/include/host_gemm.hpp @@ -0,0 +1,87 @@ +#pragma once +#include "host_tensor.hpp" +#include "gemm_common.hpp" + +template +void host_gemm(const Tensor& a, + const Tensor& b, + Tensor& c, + const GemmMatrixLayout layout) +{ + if(layout == GemmMatrixLayout::MK_KN_MN) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + auto f_mk_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + auto f_km_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + auto f_km_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( 
+ std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} diff --git a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp index 97ce326346..361f6e4a26 100644 --- a/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/solver/include/conv_tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp @@ -9,8 +9,8 @@ struct tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw int NPerBlock; int KPerBlock; - int MPerWave; - int NPerWave; + int MPerXDL; + int NPerXDL; int K1; int MRepeat; @@ -45,8 +45,8 @@ static tunable_dyn_conv_fwd_v4r4_xdlops_nchw_kcyx_nkhw 128, // MPerBlock, 128, // NPerBlock, 4, // KPerBlock, - 32, // MPerWave, - 32, // NPerWave, + 32, // MPerXDL, + 32, // NPerXDL, 4, // K1, 2, // MRepeat, 2, // NRepeat, diff --git a/script/run.sh b/script/run.sh index ecb5c85d81..3b383fcf3a 100755 --- a/script/run.sh +++ b/script/run.sh @@ -12,13 +12,16 @@ #export OLC_DEBUG_HIP_DUMP=1 #export OLC_DEBUG_SAVE_TEMP_DIR=1 - make -j conv_fwd_driver_offline - make -j conv_bwd_driver_offline - make -j conv_fwd_driver_online - #rm -rf /root/_hip_binary_kernels_/ #rm -rf /tmp/olCompile* +#make -j conv_fwd_driver_offline +#make -j conv_bwd_driver_offline +#make -j conv_wrw_driver_offline +#make -j conv_fwd_driver_online + + make -j gemm_driver_offline + LAYOUT=$1 ALGO=$2 VERIFY=$3 @@ -30,7 +33,7 @@ REPEAT=$6 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 - ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 
+#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 @@ -44,4 +47,12 @@ REPEAT=$6 #./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 + +#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 + +################################################ layout algo verify init log repeat M___ N___ K___ +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 + ./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192