From ba6f79a75e65610871fd5139311817642292085c Mon Sep 17 00:00:00 2001
From: zjing14
Date: Thu, 19 Aug 2021 01:00:41 -0500
Subject: [PATCH 01/15] Added host_conv_wrw for verification (#15)

* added host conv wrw
---
 host/host_tensor/include/host_conv_wrw.hpp | 89 ++++++++++++++++++++++
 1 file changed, 89 insertions(+)
 create mode 100644 host/host_tensor/include/host_conv_wrw.hpp

diff --git a/host/host_tensor/include/host_conv_wrw.hpp b/host/host_tensor/include/host_conv_wrw.hpp
new file mode 100644
index 0000000000..ed3e8c3042
--- /dev/null
+++ b/host/host_tensor/include/host_conv_wrw.hpp
@@ -0,0 +1,89 @@
+#pragma once
+#include "host_tensor.hpp"
+
+template <typename TOut,
+          typename TIn,
+          typename TWei,
+          typename ConvStrides,
+          typename ConvDilations,
+          typename InLeftPads,
+          typename InRightPads>
+void host_direct_convolution_backward_weights(
+    const Tensor<TOut>& out,
+    const Tensor<TIn>& in,
+    Tensor<TWei>& wei,
+    const ConvStrides& conv_strides,
+    const ConvDilations& conv_dilations,
+    const InLeftPads& in_left_pads,
+    const InRightPads&,
+    const ConvTensorLayout layout = ConvTensorLayout::NCHW)
+{
+    using namespace ck;
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    auto f_kcyx = [&](auto k, auto c, auto y, auto x) {
+        double v = 0;
+        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
+        {
+            for(int ho = 0; ho < out.mDesc.GetLengths()[2]; ++ho)
+            {
+                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
+                for(int wo = 0; wo < out.mDesc.GetLengths()[3]; ++wo)
+                {
+                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
+                    if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
+                       wi < in.mDesc.GetLengths()[3])
+                    {
+                        v += static_cast<double>(in(n, c, hi, wi)) *
+                             static_cast<double>(out(n, k, ho, wo));
+                    }
+                }
+            }
+        }
+        wei(k, c, y, x) = v;
+    };
+
+    auto f_kyxc = [&](auto k, auto y, auto x, auto c) {
+        double v = 0;
+        for(int n = 0; n < out.mDesc.GetLengths()[0]; ++n)
+        {
+            for(int ho = 0; ho < out.mDesc.GetLengths()[1]; ++ho)
+            {
+                int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
+                for(int wo = 0; wo < out.mDesc.GetLengths()[2]; ++wo)
+                {
+                    int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
+                    if(hi >= 0 && hi < in.mDesc.GetLengths()[1] && wi >= 0 &&
+                       wi < in.mDesc.GetLengths()[2])
+                    {
+                        v += static_cast<double>(in(n, hi, wi, c)) *
+                             static_cast<double>(out(n, ho, wo, k));
+                    }
+                }
+            }
+        }
+        wei(k, y, x, c) = v;
+    };
+
+    if(layout == ConvTensorLayout::NCHW)
+    {
+        make_ParallelTensorFunctor(f_kcyx,
+                                   wei.mDesc.GetLengths()[0],
+                                   wei.mDesc.GetLengths()[1],
+                                   wei.mDesc.GetLengths()[2],
+                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
+    }
+    else if(layout == ConvTensorLayout::NHWC)
+    {
+        make_ParallelTensorFunctor(f_kyxc,
+                                   wei.mDesc.GetLengths()[0],
+                                   wei.mDesc.GetLengths()[1],
+                                   wei.mDesc.GetLengths()[2],
+                                   wei.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
+    }
+    else
+    {
+        throw std::runtime_error("wrong! 
not supported layout"); + } +} From a2ad6d35315555cddeb7cf6e76fc5eee3864e6f6 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Thu, 19 Aug 2021 09:54:10 -0500 Subject: [PATCH 02/15] refactor dynamic xdlops iGemm (#13) * xdlops refactor * fixed commnt * clean xdlops_gemm * add make c into xldops-gemm * change mfma_info * refactor xdlops, hide c desc * clean * clean * clean * apply hacks changes to v4r4r4_nhwc * rename hacks and use single stage adapter * enable fp16 mfma --- .../blockwise_gemm_xdlops.hpp | 585 ++++------- .../gridwise_gemm_xdlops_v2r3.hpp | 264 ++--- .../include/tensor_operation/xdlops_gemm.hpp | 947 +++++++++--------- ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 40 +- ...icit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp | 229 ----- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 32 +- .../include/driver_gemm_xdlops_v2r3.hpp | 16 +- .../src/conv_fwd_driver_offline.cpp | 4 +- 8 files changed, 790 insertions(+), 1327 deletions(-) delete mode 100644 host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index ee6a0b7427..0c381de0d8 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -9,16 +9,15 @@ namespace ck { template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { - - using CIndex = MultiIndex<2>; - static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -26,111 +25,165 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 static constexpr index_t WaveSize = 64; - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t KPerBlock = K0; - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); - __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + const auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); - __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } - - __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } __device__ static auto CalculateAThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t 
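// A minimal sketch of what GetWaveIdx() above computes: the merge
// transform is equivalent to the plain integer arithmetic below
// (names here are illustrative only, not part of the library):
//
//   const index_t tid      = get_thread_local_1d_id();
//   const index_t waveId   = tid / WaveSize;   // wave within the block
//   const index_t waveId_m = waveId / NWaves;  // -> wave_idx[I0]
//   const index_t waveId_n = waveId % NWaves;  // -> wave_idx[I1]
//   const index_t laneId   = tid % WaveSize;   // -> wave_idx[I2]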
waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0); } __device__ static auto CalculateBThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0); } template - __device__ static CIndex + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { + const auto wave_idx = GetWaveIdx(); - const index_t waveId = get_thread_local_1d_id() / WaveSize; + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - return CIndex{m_offset, n_offset}; + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); } - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() - : a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), "wrong! 
Desc should be known at compile-time"); - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); + static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), + "wrong! K0 dimension not consistent"); - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2), "wrong! K1 dimension not consistent"); static_assert(BlockSize == MWaves * NWaves * WaveSize, "BlockSize != MWaves * NWaves * WaveSize\n"); - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); } + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() + { + constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc); + } + + template + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); + } + + __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor(); + static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor(); + template __device__ void Run(const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, @@ -141,49 +194,48 @@ struct 
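// A worked sketch of the dimension split behind the descriptors above,
// assuming the usual CK decomposition (example numbers only):
//   MPerBlock = MRepeat * MWaves * MPerXDL, e.g. 128 = 2 * 2 * 32,
// and the 8-d C index (M0, N0, M1, N1, M2, M3, M4, N2) reads as
//   M0, N0     : repeat index per wave (MRepeat, NRepeat)
//   M1, N1     : wave index within the block (MWaves, NWaves)
//   M2, M3, M4 : per-lane layout of one xdlops output block, assumed to be
//                (num_groups_per_blk, num_input_blks, group_size)
//   N2         : lane column within the block (num_threads_per_blk)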
BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + vector_type a_thread_vec; - vector_type a_thread_vec; + vector_type b_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPerBlock, xdlops_gemm.KPerXdlops>{}([&](auto k) { + static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) { // read A - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), + a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, + make_tuple(k0, I0, I0, I0, I0), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0), a_thread_buf); // read B - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), + b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, + make_tuple(k0, I0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0), b_thread_buf); - using mfma_input_type = - typename vector_type::type; - - static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - a_thread_vec.template AsType()(Number{}) = a_thread_buf[Number{}]; - }); - - static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - b_thread_vec.template AsType()(Number{}) = b_thread_buf[Number{}]; - }); + using mfma_input_type = typename vector_type::type; static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf); + static_for<0, K1, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + }); + + static_for<0, K1, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf); }); }); }); @@ -191,333 +243,38 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 private: // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, I1, Number{})); // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, I1, Number{})); - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, MRepeat, 1, 1, K1>, + Sequence<0, 1, 2, 3, 4>, + 4, K1, 1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, NRepeat, 1, 1, K1>, + Sequence<0, 1, 2, 3, 4>, + 4, K1, 1>; - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline -{ - - using CIndex = MultiIndex<2>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto xdlops_gemm = 
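// Control flow of the new Run() above, as a sketch:
//
//   for(k0 = 0; k0 < KPerBlock; k0 += KPerXdlops / KPerThread)
//       copy A[k0, m-repeat, wave, lane, 0..K1) from LDS into registers
//       copy B[k0, n-repeat, wave, lane, 0..K1) from LDS into registers
//       for(m0 = 0; m0 < MRepeat; ++m0)
//           for(n0 = 0; n0 < NRepeat; ++n0)
//               pack K1 scalars into one vector_type and issue the MFMA,
//               accumulating into c_thread_buf at offset (m0, n0, 0)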
XdlopsGemm{}; - - static constexpr index_t WaveSize = 64; - - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); - - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; - - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; - - __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } - - __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } - - __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } - } - - template - __device__ static CIndex - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) - { - - const index_t waveId = get_thread_local_1d_id() / WaveSize; - - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; - - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; - - return CIndex{m_offset, n_offset}; - } - - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() - : a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} - { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), - "wrong! 
K1 dimension not consistent"); - - static_assert(BlockSize == MWaves * NWaves * WaveSize, - "BlockSize != MWaves * NWaves * WaveSize\n"); - - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - static_for{}([&](auto k) { - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - }); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - } - - private: - // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, 
- Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 207f73072f..51af11fefc 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -18,7 +18,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -29,7 +29,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const AK0MK1GridDesc a_k0_m_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc, const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -43,7 +43,7 @@ __global__ void p_shared_block, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER @@ -52,7 +52,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -63,7 +63,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const void CONSTANT* p_c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -73,8 +73,9 @@ __global__ void cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); const auto b_k0_n_k1_grid_desc = *reinterpret_cast( cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); - const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc)); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); const auto c_block_cluster_adaptor = *reinterpret_cast( cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); @@ -86,7 +87,7 @@ __global__ void p_shared_block, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); } #endif @@ -138,6 +139,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -201,29 +205,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } __host__ __device__ static constexpr auto - MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) { - constexpr auto xdlops_gemm = XdlopsGemm{}; + constexpr auto max_lds_align = K1; - constexpr auto CLayout = xdlops_gemm.GetCLayout(); + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); - constexpr auto M0 = Number{}; - constexpr auto 
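// Sketch of the CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
// path above: each descriptor is staged in a device buffer by the host
// (the host side is an assumption here) and recovered in the kernel with
//
//   const auto a_k0_m_k1_grid_desc =
//       *reinterpret_cast<const AK0MK1GridDesc*>(
//           cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc));
//
// which keeps the kernel ABI independent of the heavy descriptor types.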
M1 = Number{}; - constexpr auto M2 = Number{}; + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); - constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - constexpr auto N1 = Number{}; - - const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), - make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); - - return c_m0_m1_m2_n_grid_desc; + return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); } __host__ __device__ static constexpr auto @@ -253,8 +256,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return c_blockid_to_m0_n0_block_cluster_adaptor; } - using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -262,7 +265,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAB* __restrict__ p_shared_block, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor& c_block_cluster_adaptor) { const auto a_grid_buf = make_dynamic_buffer( @@ -270,7 +273,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); @@ -358,50 +361,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // register // sanity check - static_assert(MPerBlock % (MPerWave * MRepeat) == 0 && - NPerBlock % (NPerWave * NRepeat) == 0, - "wrong!"); - - constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - const auto blockwise_gemm = - BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; - - constexpr auto CLayout = blockwise_gemm.GetCLayout(); - - constexpr index_t BlkSize = CLayout.GetBlkSize(); - constexpr index_t NumBlks = CLayout.GetNumBlks(); - constexpr index_t NumXdlops = CLayout.GetNumXdlops(); - - static_assert(NumBlks == 1 && NumXdlops 
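// The block-cluster adaptor built above merges (M0, N0) block tiles into a
// flat workgroup id; a plain-arithmetic sketch of the inverse mapping:
//
//   M0 = M / MPerBlock;  N0 = N / NPerBlock;
//   m_block = get_block_1d_id() / N0;
//   n_block = get_block_1d_id() % N0;
//   m_block_data_idx_on_grid = m_block * MPerBlock;
//   n_block_data_idx_on_grid = n_block * NPerBlock;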
== 1, "K Reduction Mfma only"); + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; constexpr auto c_mr_nr_blk_desc = make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor(); + constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize(); + StaticBuffer, + vector_type, c_mr_nr_blk_desc.GetElementSpaceSize(), true> c_thread_buf; @@ -474,41 +453,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } -#if 0 // output: register to global memory { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); - constexpr index_t N0 = CLayout.N1(); - constexpr index_t N1 = CLayout.N0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number<1>{}, - Number<1>{}, - Number{}, - Number<1>{}, - Number{}, - Number<1>{})); - - StaticBuffer - c_blk_buf_; - - static_for<0, MRepeat, 1>{}([&](auto mr_i) { - static_for<0, NRepeat, 1>{}([&](auto nr_i) { - constexpr auto blk_off = - c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i)); - - static_for<0, BlkSize, 1>{}([&](auto j) { - c_blk_buf_(Number{}) = - c_thread_buf[Number{}] - .template AsType()[Number{}]; - }); - }); - }); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -521,145 +473,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; - - constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(c_m0_m1_m2_n_thread_desc), - decltype(c_m0_m1_m2_n_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_m0_m1_m2_n_grid_desc, - make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves), - n_thread_data_on_grid / (N1 * NWaves), - m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0), - n_thread_data_on_grid % (N1 * NWaves) / N1, - m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, - n_thread_data_on_grid % N1)} - .Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_blk_buf_, - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - } -#else - { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - 
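// In the kept output path below, the per-thread M coordinate on the grid
// is decomposed against the (M2, M3, M4) lengths taken from the block
// descriptor; a sketch of that arithmetic:
//
//   m2 = m_thread_data_on_grid / (M3 * M4);
//   m3 = m_thread_data_on_grid % (M3 * M4) / M4;
//   m4 = m_thread_data_on_grid % M4;
//   dst origin = (0, 0, 0, 0, m2, m3, m4, n_thread_data_on_grid)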
blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), + Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, true>{ - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, make_multi_index(0, 0, 0, 0, - m_thread_data_on_grid / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, + m_thread_data_on_grid / (M3 * M4), + m_thread_data_on_grid % (M3 * M4) / M4, + m_thread_data_on_grid % M4, n_thread_data_on_grid)}; auto init_copy = [&](auto c_thread_idx_) { constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); return c_thread_idx_; }; auto mrepeat_plus_copy = [&](auto c_thread_idx_) { constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + mrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto nrepeat_plus_copy = [&](auto c_thread_idx_) { constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + nrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto mrepeat_minus_copy = [&](auto c_thread_idx_) { constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + mrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, 
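// The init/mrepeat/nrepeat lambdas here implement a serpentine walk over
// the (MRepeat, NRepeat) C tiles, moving the destination window by +/-1 in
// one repeat dimension per step. For a 2x2 tile one legal order would be
// (a sketch; the actual order is picked by the dispatch further down):
//
//   init_copy((0,0));          nrepeat_plus_copy((0,1));
//   mrepeat_plus_copy((1,1));  nrepeat_minus_copy((1,0));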
I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto nrepeat_minus_copy = [&](auto c_thread_idx_) { constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + nrepeat_step_minus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or @@ -791,7 +694,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 init_copy(make_tuple(I0, I0)); } } -#endif } }; // namespace ck diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index affe096ace..56581c024b 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -7,21 +7,18 @@ namespace ck { -enum struct mfma_instr +enum struct MfmaInstr { - /// fp32 mfma_f32_32x32x1xf32 = 0, mfma_f32_16x16x1xf32, mfma_f32_4x4x1xf32, mfma_f32_32x32x2xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction - /// fp16 mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, mfma_f32_32x32x8f16, // k reduction mfma_f32_16x16x16f16, // k reduction - /// bfp16 mfma_f32_32x32x2bf16, mfma_f32_16x16x2bf16, mfma_f32_4x4x2bf16, @@ -29,25 +26,23 @@ enum struct mfma_instr mfma_f32_16x16x8bf16, // k reduction }; -template -struct mfma_info; +template +struct mfma_type; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 1; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 
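// Worked register arithmetic for mfma_f32_32x32x1xf32 above: one 32x32
// output block holds 32 * 32 = 1024 accumulators; spread over wave_size =
// 64 lanes that is num_regs_per_blk = 16 per lane, organized as
// num_groups_per_blk = 4 groups of group_size = 4; with num_output_blks =
// 2 each lane carries 32 f32 accumulators in total.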
32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 1; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template // treat 4x4x1 as a single-blk 4x64 mfma template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static 
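// "is_k_reduction" above means the wave's num_input_blks lane groups feed
// different K slices of the same output block rather than different output
// blocks. For mfma_f32_32x32x2xf32: num_input_blks = 2, so the two 32-lane
// halves each contribute one k, giving KPerXdlops = 2 * k_per_blk = 2.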
constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 1; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 8; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - 
static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 16; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 4; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template #if 0 template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t 
num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 8; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - 
static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 2; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 2; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; #endif -template -struct xdlops_info +template +struct MfmaSelector { - static constexpr auto mfma_type = mfma_info{}; + template + static constexpr auto GetMfma(); - static constexpr index_t MPerXdlops = MPerXdlops_; - static constexpr index_t NPerXdlops = NPerXdlops_; + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x2xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return 
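// Usage sketch for the selector, assuming the specializations above are
// keyed on (base type, MPerXdlops, NPerXdlops):
//
//   using sel = MfmaSelector<half_t, 32, 32>;   // assumed template args
//   constexpr auto mfma = sel::selected_mfma;   // mfma_type<mfma_f32_32x32x8f16>
//   static_assert(mfma.k_per_blk == 4, "fp16 packs 4 k per register");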
MfmaInstr::mfma_f32_32x32x8f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x16f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + +#if 0 + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } +#endif + + static constexpr auto selected_mfma = mfma_type()>{}; + + __host__ __device__ static constexpr void mfma_check() + { + static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == + selected_mfma.num_regs_per_blk, + "wrong! num_regs_per_blk"); + + static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks == + selected_mfma.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks || + selected_mfma.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size == + selected_mfma.m_per_blk * selected_mfma.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_mfma.is_k_reduction || + (selected_mfma.num_input_blks == selected_mfma.num_output_blks), + "is_k_reduction wrong!"); + } + + __host__ __device__ constexpr MfmaSelector() { mfma_check(); } static constexpr bool IsABroadcast() { @@ -505,186 +652,33 @@ struct xdlops_info return true; } - static constexpr bool IsKReduction() - { - return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); - } - static constexpr index_t GetKPerXdlops() { - return IsKReduction() ? mfma_type.num_input_blks : 1; + return (selected_mfma.is_k_reduction ? 
selected_mfma.num_input_blks : 1) * + selected_mfma.k_per_blk; } - static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } + static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } }; -template +template struct XdlopsGemm { - template - static constexpr auto GetXdlopsInfo(); - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - -#if 0 - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } -#endif + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; using CIndex = MultiIndex<2>; - __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } __device__ static constexpr index_t GetNumXdlops() { - return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + return MPerXdlops * NPerXdlops / + (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks); } __host__ __device__ constexpr XdlopsGemm() @@ -697,104 +691,141 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); - 
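// [editor's note -- annotation, not part of the patch] GetKPerXdlops() and
// GetKPerThread() above encode the K-reduction contract. A hedged worked
// example, taking an mfma_type entry with num_input_blks == 2 and
// k_per_blk == 2:
//
//     GetKPerXdlops() == 2 * 2 == 4   // K consumed by the whole wave per MFMA issue
//     GetKPerThread() == 2            // K actually loaded by a single lane
//
// Under K-reduction each lane supplies only its own K slice and the MFMA
// hardware accumulates across the input blocks; without it the two values
// coincide at k_per_blk.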
static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, - "m != num_input_blks * num_regs_blk"); - static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || - mfma_type.num_output_blks == 1, - "incorrect num_output_blks"); - static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n, - "num_regs_blk incorrect"); + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by k_per_blk"); + } - static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); + template + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc) + { + const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0); + const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1); + const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2); + const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3); + + return transform_tensor_descriptor( + c_m0_n0_m1_n1_m2_n2_desc, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); } __device__ static constexpr index_t GetRegSizePerXdlops() { - return MPerXdlops * NPerXdlops / mfma_type.wave_size; + return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } - template + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || is_same::value, "base base_type must be float, half, ushort!"); - static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); - - constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); - - static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { - constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); - constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); - - mfma_type.template run( - p_a_wave[Number{}], - p_b_wave[Number{}], - p_c_thread); + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + mfma_instr.template run( + p_a_wave[k], p_b_wave[k], p_c_thread); }); } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); 
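            // [editor's note] Under K-reduction the lane's block decomposition
            // doubles as its source coordinate: blk_id selects which of the
            // num_input_blks K slices this lane feeds, and blk_td selects the
            // position inside the m_per_blk-wide block. In the non-reduction
            // branch below, every lane starts at k == 0 and is distinguished
            // by laneId alone.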
+ } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) { - const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; - const index_t blk_id = laneId / mfma_type.num_threads_blk; - const index_t blk_td = laneId % mfma_type.num_threads_blk; + const auto blk_idx = GetBlkIdx(); - index_t n_offset = blk_i * mfma_type.n + blk_td; - index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size; return CIndex{m_offset, n_offset}; } - static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; - static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; - static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; - static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; + static constexpr auto mfma = MfmaSelector{}; - static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); - static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); - static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); + static constexpr auto mfma_instr = mfma.selected_mfma; - static constexpr auto GetBlkId(const index_t lane_id) + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto KPerThread = mfma.GetKPerThread(); + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() { - return lane_id / mfma_type.num_threads_blk; + return make_tuple( + Number{}, I1, Number{}, I1); } - - static constexpr auto GetBlkTd(const index_t lane_id) - { - return lane_id % mfma_type.num_threads_blk; - } - - static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; - - struct CLayout - { - __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } - __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } - __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } - __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } - - __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } - - __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } - - __device__ static constexpr index_t GetNumXdlops() - { - return MPerXdlops * NPerXdlops / - (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); - } - }; - - __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } }; } // namespace ck diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 695ffeeb36..34de1ac0ed 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -48,10 +48,10 @@ void 
device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); #if 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; - constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; @@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( constexpr index_t GemmNPerWave = 32; constexpr index_t GemmK1 = 8; - constexpr index_t MRepeat = 4; + constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; @@ -106,22 +106,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 141a326574..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,229 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r2.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const 
InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto 
wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1>, - 2, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = 
wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 7067291c8a..a392df6aa8 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -250,22 +250,22 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index edfce52a19..a33e5bf4d0 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, "wrong! 
GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } - const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); @@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, FloatC, remove_reference_t, remove_reference_t, - remove_reference_t, + remove_reference_t, remove_reference_t>; #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE @@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_c_grid, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); - c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); float ave_time = launch_and_time_kernel( @@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_c_grid, cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); #endif return ave_time; diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 32c33003c5..1ab8a822cd 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -24,8 +24,8 @@ #define USE_CONV_FWD_V4R4R2_NHWC 1 #define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo { From c6f26bb4806e139d1e312aac3bc653a31a1d6946 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Mon, 23 Aug 2021 10:40:27 -0500 Subject: [PATCH 03/15] magic division use __umulhi() (#19) --- .../include/utility/magic_division.hpp | 25 +++---------------- 1 file changed, 3 insertions(+), 22 deletions(-) diff --git a/composable_kernel/include/utility/magic_division.hpp b/composable_kernel/include/utility/magic_division.hpp index b7489016e9..612aceea2a 100644 --- a/composable_kernel/include/utility/magic_division.hpp +++ b/composable_kernel/include/utility/magic_division.hpp @@ -114,12 
+114,11 @@ struct MagicDivision
     __host__ __device__ static constexpr uint32_t
     DoMagicDivision(uint32_t dividend, uint32_t multiplier, uint32_t shift)
     {
-        uint32_t tmp = (uint64_t(dividend) * uint64_t(multiplier)) >> 32;
+        uint32_t tmp = __umulhi(dividend, multiplier); // high 32 bits of the 64-bit product
         return (tmp + dividend) >> shift;
     }
 
-#if 1 // debug
-    // HACK: magic division for int32_t
+    // magic division for int32_t
     // HACK: use dividend_i32 as if it's uint32_t, dividend_i32 needs to be
     // non-negative for the result to be correct
     // TODO: figure out how to do magic number division for int32_t as dividend
@@ -127,27 +126,9 @@ struct MagicDivision
     DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
     {
         uint32_t dividend_u32 = as_type<uint32_t>(dividend_i32);
-        uint32_t tmp =
-            (static_cast<uint64_t>(dividend_u32) * static_cast<uint64_t>(multiplier)) >> 32;
+        uint32_t tmp = __umulhi(dividend_u32, multiplier); // high 32 bits of the 64-bit product
         return (tmp + dividend_u32) >> shift;
     }
-#else
-    // the inline ASM is producing wrong result
-    __host__ __device__ static int32_t
-    DoMagicDivision(int32_t dividend_i32, uint32_t multiplier, uint32_t shift)
-    {
-        uint32_t r;
-        asm volatile("\n \
-            v_mul_hi_u32 %0, %1, %2 \n \
-            v_add_u32_e32 %0, %1, %0 \n \
-            v_lshrrev_b32_e32 %0, %3, %0 \n \
-            "
-                     : "=v"(r)
-                     : "v"(as_type<uint32_t>(dividend_i32)), "s"(multiplier), "s"(shift));
-
-        return as_type<int32_t>(r);
-    }
-#endif
 };
 
 } // namespace ck
From 9d3f634a3cc37599db9f3ebfeb1f9a3c45d9a673 Mon Sep 17 00:00:00 2001
From: zjing14
Date: Mon, 23 Aug 2021 11:22:10 -0500
Subject: [PATCH 04/15] Xdlops refactor fix (#22)

* added constexpr ahead of adaptor; clean unused driver; rename M/NPerWave to
  M/NPerXDL

* fixed bwd

* fixed comment
---
 .../blockwise_gemm_xdlops.hpp                 |   3 +-
 .../gridwise_gemm_xdlops_v2r3.hpp             |  20 +-
 .../include/tensor_operation/xdlops_gemm.hpp  |   5 +-
 ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp |  74 ++---
 ...icit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp |  40 +--
 ...plicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp | 280 ------------------
 ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp |   6 +-
 ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp |  44 +--
 .../include/driver_gemm_xdlops_v2r3.hpp       |   8 +-
 .../src/conv_fwd_driver_offline.cpp           |   6 +-
 ...tunable_fwd_v4r4_xdlops_nchw_kcyx_nkhw.hpp |   8 +-
 11 files changed, 111 insertions(+), 383 deletions(-)
 delete mode 100644 host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
index 0c381de0d8..a8236737df 100644
--- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
+++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp
@@ -4,6 +4,7 @@
 #include "common_header.hpp"
 #include "threadwise_tensor_slice_transfer.hpp"
 #include "xdlops_gemm.hpp"
+#include "tensor_adaptor.hpp"
 
 namespace ck {
 
@@ -40,7 +41,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
     {
         const index_t thread_id = get_thread_local_1d_id();
 
-        const auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
+        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));
diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp
index 51af11fefc..3e4d74e9d8 100644
--- 
a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -103,8 +103,8 @@ template ; @@ -366,8 +368,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAB, decltype(a_k0_m_k1_block_desc), decltype(b_k0_n_k1_block_desc), - MPerWave, - NPerWave, + MPerXDL, + NPerXDL, MRepeat, NRepeat, K1>{}; diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index 56581c024b..f945b0fdf5 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -749,8 +749,9 @@ struct XdlopsGemm __device__ static auto GetBlkIdx() { - const auto laneId = GetLaneId(); - const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + const auto laneId = GetLaneId(); + + constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( make_tuple(make_merge_transform( make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), make_tuple(Sequence<0, 1, 2>{}), diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 7bd82bf6d5..7196f3c179 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,9 +56,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -84,9 +84,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -112,9 +112,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -140,9 +140,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 4; @@ -168,9 +168,9 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( 
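// [editor's note -- annotation, not part of the patch] The M/NPerWave ->
// M/NPerXDL renames in the hunks below are mechanical, but the new name is
// the honest one: the value is the tile of a single v_mfma instruction, and
// the M extent a wave actually covers is the derived product. A sketch of the
// relationship (MPerWave is the retired name, shown here only for
// translation):
//
//     constexpr ck::index_t GemmMPerXDL = 32;                    // one MFMA tile
//     constexpr ck::index_t MRepeat     = 2;                     // MFMA tiles per wave in M
//     constexpr ck::index_t MPerWave    = MRepeat * GemmMPerXDL; // 64 rows of M per wave
//
// so existing configurations carry over with the same numbers; only the name
// changes.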
constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -223,25 +223,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + //clang-format on constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}; @@ -263,8 +265,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -289,7 +291,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat @@ -301,7 +303,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( in_gemmm_gemmn_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 0ebf8571f4..2cbae2daf3 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -195,25 +195,27 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmn Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 - constexpr auto in_m0_m1_m2_n_grid_step_hacks = make_tuple( + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + 
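            // [editor's note] These rewritten rows grow from 19 to 25 entries
            // because C is now addressed through the 8-dimensional
            // (M0, N0, M1, N1, M2, M3, M4, N2) descriptor: each step-hack
            // Sequence carries one flag per hidden transform dimension, and
            // the non-zero entries keep marking the same underlying
            // transforms as before.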
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 make_tuple( - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + //clang-format on constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; @@ -265,7 +267,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk GemmCThreadTransferDstScalarPerVector, decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_m1_m2_n_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), true // CAccessOrderMRepeatNRepeat @@ -277,7 +279,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk in_gemmm_gemmn_grid_desc, out_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_m1_m2_n_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp deleted file mode 100644 index 4a9d01081c..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp +++ /dev/null @@ -1,280 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "driver_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw( - const InLengths& in_n_c_hi_wi_lengths, - const 
WeiLengths& wei_k_c_y_x_lengths, - const OutLengths& out_n_k_ho_wo_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_c_hi_wi, - const Tensor& wei_k_c_y_x, - Tensor& out_n_k_ho_wo, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); - DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); - - in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); - wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); - out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); - - const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); - const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); - const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); - -#if 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t 
GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 0 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmKPack = 4; - - constexpr index_t MRepeat = 1; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_KPack = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; - constexpr index_t GemmBBlockTransferDstScalarPerVector_KPack = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector_GemmN1 = 1; -#endif - - const auto descs = -#if 1 - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_pad -#else - transform_forward_convolution_into_gemm_v4r4_xdlops_nchw_kcyx_nkhw_1x1 -#endif - ( - wei_k_c_y_x_desc, - in_n_c_hi_wi_desc, - 
out_n_k_ho_wo_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads); - - for(index_t i = 0; i < 5; ++i) - { -#if 0 - float ave_time = launch_kernel_gemm_xdlops_v1 -#else - float ave_time = launch_kernel_gemm_xdlops_v2 -#endif - , - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK, - GemmABlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<0, 2, 1>, - Sequence<1, 0, 2>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_KPack, - false, // don't move back src coordinate after threadwise copy, which will be fused - // with MoveSrcSliceWindow() to save addr computation - Sequence<2, 3, 0, 1>, - 3, - GemmCThreadTransferDstScalarPerVector_GemmN1, - decltype(descs[I4]), - decltype(descs[I5]), - decltype(descs[I6]), - decltype(descs[I7]), - decltype(descs[I8])>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - descs[I0], - descs[I1], - descs[I2], - descs[I3], - descs[I4], - descs[I5], - descs[I6], - descs[I7], - descs[I8], - nrepeat); - - float perf = (float)calculate_convolution_flops( - in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; - } - - // copy result back to host - out_n_k_ho_wo_device_buf.FromDevice(out_n_k_ho_wo.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 34de1ac0ed..dc4f5eafb6 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - constexpr auto out_m0_m1_m2_n_grid_step_hacks = + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, @@ -169,7 +169,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( GemmCThreadTransferDstScalarPerVector, decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false>(static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), @@ -180,7 +180,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( out_gemmm_gemmn_grid_desc, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, 
in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index a392df6aa8..5ff8dfb665 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -56,8 +56,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 4; @@ -84,9 +84,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -112,9 +112,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 4; @@ -140,9 +140,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 4; constexpr index_t NRepeat = 2; @@ -168,9 +168,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 256; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 4; @@ -196,9 +196,9 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; @@ -249,7 +249,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 - constexpr auto out_m0_m1_m2_n_grid_step_hacks = + constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: 
MRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves @@ -287,8 +287,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmMPerBlock, GemmNPerBlock, GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, + GemmMPerXDL, + GemmNPerXDL, GemmK1, MRepeat, NRepeat, @@ -313,7 +313,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( GemmCThreadTransferDstScalarPerVector, decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), + decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), false // CAccessOrderMRepeatNRepeat @@ -325,7 +325,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( out_gemmm_gemmn_grid_desc, in_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, + out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, nrepeat); diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index a33e5bf4d0..2bf8adba84 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -17,8 +17,8 @@ template Date: Wed, 25 Aug 2021 10:55:55 -0500 Subject: [PATCH 05/15] GlobalAtomicAdd for fp32/int32 (#23) * add f32/i32 atomicAdd support into dynamicBuffer, and enable it in v1r3 * fixed * fixed * update comment Co-authored-by: Chao Liu --- .../threadwise_tensor_slice_transfer.hpp | 18 +- .../include/utility/amd_buffer_addressing.hpp | 175 +++++++++++++++++- composable_kernel/include/utility/config.hpp | 4 +- .../include/utility/dynamic_buffer.hpp | 32 ++++ 4 files changed, 222 insertions(+), 7 deletions(-) diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 0071accf7f..e38dbbc8b5 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -208,10 +208,20 @@ struct ThreadwiseTensorSliceTransfer_v1r3 coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_desc, dst_coord_); // copy data from dst_vector into dst_buf - dst_buf.template Set( - dst_coord_.GetOffset(), - is_dst_valid, - dst_vector.template AsType()[Number<0>{}]); + if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::Set) + { + dst_buf.template Set( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } + else if constexpr(DstInMemOp == InMemoryDataOperationEnum_t::AtomicAdd) + { + dst_buf.template AtomicAdd( + dst_coord_.GetOffset(), + is_dst_valid, + dst_vector.template AsType()[Number<0>{}]); + } constexpr auto move_on_dim = [&]() constexpr { diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index 57081b7fd7..b7fd4bc409 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -202,6 +202,22 @@ 
llvm_amdgcn_raw_buffer_store_fp32x4(float4_t vdata, index_t voffset, index_t soffset, index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.store.v4f32"); +// atomic add +// int +__device__ int32_t llvm_amdgcn_raw_buffer_atomic_add_i32( + int32_t vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.add.i32"); + +// float +__device__ float llvm_amdgcn_raw_buffer_atomic_add_fp32( + float vdata, + int32x4_t rsrc, + index_t voffset, + index_t soffset, + index_t glc_slc) __asm("llvm.amdgcn.raw.buffer.atomic.fadd.f32"); template __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_wave_buffer_resource, @@ -581,6 +597,128 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src } } +template +__device__ void amd_buffer_atomic_add_impl(const typename vector_type::type src_thread_data, + int32x4_t dst_wave_buffer_resource, + index_t dst_thread_addr_offset, + index_t dst_wave_addr_offset) +{ + static_assert((is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4)), + "wrong! not implemented"); + + if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_fp32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(float), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_fp32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(float), + 0); + } + } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_atomic_add_i32(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + } + else if constexpr(N == 4) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + sizeof(int32_t), + 0); + + 
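+            // Lanes 2 and 3 below repeat the per-lane pattern used for lanes
+            // 0 and 1 above: the raw buffer atomic-add intrinsics declared at
+            // the top of this file are scalar-only, so an N-element vector is
+            // decomposed into N calls, the i-th shifted by i * sizeof(int32_t).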
llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<2>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 2 * sizeof(int32_t), + 0); + + llvm_amdgcn_raw_buffer_atomic_add_i32(tmp.AsType()[Number<3>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 3 * sizeof(int32_t), + 0); + } + } +} + // buffer_load requires: // 1) p_src_wave must be in global memory space // 2) p_src_wave must be a wavewise pointer. @@ -645,7 +783,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, // buffer_store requires: // 1) p_dst_wave must be global memory -// 2) p_dst_wave to be a wavewise pointer. +// 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template __device__ void amd_buffer_store(const typename vector_type_maker::type::type src_thread_data, @@ -677,5 +815,40 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t #endif } +// buffer_atomic_add requires: +// 1) p_dst_wave must be global memory +// 2) p_dst_wave must be a wavewise pointer. +// It is user's responsibility to make sure that is true. +template +__device__ void +amd_buffer_atomic_add(const typename vector_type_maker::type::type src_thread_data, + T* p_dst_wave, + const index_t dst_thread_element_offset, + const bool dst_thread_element_valid, + const index_t dst_element_space_size) +{ + const int32x4_t dst_wave_buffer_resource = + make_wave_buffer_resource(p_dst_wave, dst_element_space_size); + + index_t dst_thread_addr_offset = dst_thread_element_offset * sizeof(T); + + using vector_t = typename vector_type_maker::type::type; + using scalar_t = typename scalar_type::type; + constexpr index_t vector_size = scalar_type::vector_size; + +#if CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK + uint32_t dst_addr_shift = dst_thread_element_valid ? 
0 : 0x7fffffff; + + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_addr_shift + dst_thread_addr_offset, 0); +#else + if(dst_thread_element_valid) + { + amd_buffer_atomic_add_impl( + src_thread_data, dst_wave_buffer_resource, dst_thread_addr_offset, 0); + } +#endif +} + } // namespace ck #endif diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index 521ad24d47..c229162d9b 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -85,8 +85,8 @@ #define CK_EXPERIMENTAL_USE_BUFFER_STORE_OOB_CHECK_OFFSET_TRICK 1 #endif -#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK -#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_OOB_CHECK_OFFSET_TRICK 1 +#ifndef CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK +#define CK_EXPERIMENTAL_USE_BUFFER_ATOMIC_ADD_OOB_CHECK_OFFSET_TRICK 1 #endif // pass tensor descriptor by value or void* diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index 4d583e3ce7..a875afd9be 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -223,6 +223,38 @@ struct DynamicBuffer } } + template >>::type, + typename scalar_type>>::type>::value, + bool>::type = false> + __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) + { + // X contains multiple T + constexpr index_t scalar_per_t_vector = + scalar_type>>::vector_size; + + constexpr index_t scalar_per_x_vector = + scalar_type>>::vector_size; + + static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, + "wrong! X need to be multiple T"); + + static_assert(GetAddressSpace() == AddressSpaceEnum_t::Global, "only support global mem"); + +#if CK_USE_AMD_BUFFER_ADDRESSING + constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; + + amd_buffer_atomic_add>, t_per_x>( + x, p_data_, i, is_valid_element, element_space_size_); +#else + if(is_valid_element) + { + atomicAdd(&p_data_[i], x); + } +#endif + } + __host__ __device__ static constexpr bool IsStaticBuffer() { return false; } __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; } From 10bb81106072e7f9de1c7ce0ed7880e41bd9f517 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 26 Aug 2021 20:05:19 -0500 Subject: [PATCH 06/15] Misc fixes (#24) * use cast_pointer_to_generic_address_space() in v6r1 kernel wrapper, DynamcBuffer and buffer_load take customized invalid-element-value, add buffer_load/store for fp64 * use remove_cvref_t --- .../tensor_description/tensor_adaptor.hpp | 3 +- .../tensor_description/tensor_descriptor.hpp | 7 +- .../blockwise_gemm_dlops_v3.hpp | 12 +- .../threadwise_contraction_dlops.hpp | 42 ++-- .../threadwise_gemm_dlops_v3.hpp | 21 +- .../threadwise_tensor_slice_set.hpp | 4 +- .../threadwise_tensor_slice_transfer.hpp | 59 +++--- .../threadwise_tensor_slice_transfer_v2.hpp | 35 ++-- .../include/utility/amd_buffer_addressing.hpp | 183 +++++++++++------- composable_kernel/include/utility/array.hpp | 2 +- .../include/utility/data_type.hpp | 11 ++ .../include/utility/dynamic_buffer.hpp | 122 ++++++------ composable_kernel/include/utility/tuple.hpp | 2 +- .../include/utility/tuple_helper.hpp | 4 +- composable_kernel/include/utility/type.hpp | 3 + ...mplicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp | 9 +- 16 files changed, 267 insertions(+), 252 deletions(-) diff --git 
a/composable_kernel/include/tensor_description/tensor_adaptor.hpp b/composable_kernel/include/tensor_description/tensor_adaptor.hpp index 3b647e433a..50a8088bba 100644 --- a/composable_kernel/include/tensor_description/tensor_adaptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_adaptor.hpp @@ -189,8 +189,7 @@ struct TensorAdaptor bool is_known = true; static_for<0, Transforms::Size(), 1>{}([&](auto i) { - is_known &= - remove_cv_t>::IsKnownAtCompileTime(); + is_known &= remove_cvref_t::IsKnownAtCompileTime(); }); return is_known && is_known_at_compile_time::value; diff --git a/composable_kernel/include/tensor_description/tensor_descriptor.hpp b/composable_kernel/include/tensor_description/tensor_descriptor.hpp index a6a57ba63b..8f6a5a3e43 100644 --- a/composable_kernel/include/tensor_description/tensor_descriptor.hpp +++ b/composable_kernel/include/tensor_description/tensor_descriptor.hpp @@ -185,8 +185,7 @@ struct TensorDescriptor bool is_known = true; static_for<0, Transforms::Size(), 1>{}([&](auto i) { - is_known &= - remove_cv_t>::IsKnownAtCompileTime(); + is_known &= remove_cvref_t::IsKnownAtCompileTime(); }); return is_known && is_known_at_compile_time::value && @@ -587,11 +586,11 @@ __host__ __device__ constexpr bool coordinate_has_valid_offset(const TensorDesc& template using TensorCoordinate_t = decltype(make_tensor_coordinate( - TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); template using TensorCoordinateStep_t = decltype(make_tensor_coordinate_step( - TensorDesc{}, MultiIndex>::GetNumOfDimension()>{})); + TensorDesc{}, MultiIndex::GetNumOfDimension()>{})); } // namespace ck #endif diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp index 03f889649e..5cc2f2393e 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp @@ -110,13 +110,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3 const BThreadBuffer& b_thread_buf, CThreadBuffer& c_thread_buf) const { - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; diff --git a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp index a925a5cd68..8b75381026 100644 --- a/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_contraction_dlops.hpp @@ -55,19 +55,16 @@ struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 CBuffer& c_buf, COriginIdx) { - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -157,19 +154,16 @@ struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_ CBuffer& c_buf, COriginIdx) { - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp index 015ad675fb..f6c15fd85a 100644 --- a/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp @@ -41,19 +41,16 @@ struct ThreadwiseGemmDlops_km_kn_mn_v3 CDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert( - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value && - is_known_at_compile_time>>::value, - "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp index 0c7aa978a7..20e9a5b366 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_set.hpp @@ -30,11 +30,11 @@ struct ThreadwiseTensorSliceSet_v1 static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert(is_known_at_compile_time>>::value, + static_assert(is_known_at_compile_time>::value, "wrong! 
OriginIdx need to be known at compile-time"); // Desc is known at compile-time - constexpr auto desc = remove_cv_t>{}; + constexpr auto desc = remove_cvref_t{}; // OriginIdx is known at compile-time constexpr auto origin_idx = to_multi_index(OriginIdx{}); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index e38dbbc8b5..d5c77f4a54 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -95,18 +95,13 @@ struct ThreadwiseTensorSliceTransfer_v1r3 static_assert(SrcDesc::IsKnownAtCompileTime(), "wrong! SrcDesc need to known at compile-time"); - static_assert( - is_known_at_compile_time>>::value, - "wrong! SrcSliceOrigin need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! SrcSliceOrigin need to known at compile-time"); static_assert(SrcBuffer::IsStaticBuffer(), "wrong! SrcBuffer need to be StaticBuffer"); - // static_assert(is_same>, - // remove_cv_t>>::value, - //"wrong! SrcBuffer data type is wrong"); - // SrcDesc and src_slice_origin_idx are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{}); constexpr auto I0 = Number<0>{}; @@ -421,16 +416,15 @@ struct ThreadwiseTensorSliceTransfer_v2 static_assert(DstDesc::IsKnownAtCompileTime(), "wrong! DstDesc need to known at compile-time"); - static_assert( - is_known_at_compile_time>>::value, - "wrong! DstSliceOrigin need to known at compile-time"); + static_assert(is_known_at_compile_time>::value, + "wrong! DstSliceOrigin need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - "wrong! inconsistent type"); + static_assert( + is_same, remove_cvref_t>::value && + "wrong! inconsistent type"); // DstDesc and dst_slice_origin_idx are known at compile-time - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto dst_desc = remove_cvref_t{}; constexpr auto dst_slice_origin_idx = DstSliceOriginIdx{}; constexpr auto I0 = Number<0>{}; @@ -742,9 +736,9 @@ struct ThreadwiseTensorSliceTransfer_v3 SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -899,9 +893,9 @@ struct ThreadwiseTensorSliceTransfer_v3 DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); constexpr auto I0 = Number<0>{}; constexpr auto I1 = Number<1>{}; @@ -1315,24 +1309,21 @@ struct ThreadwiseTensorSliceTransfer_v4 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! 
SrcBuffer or DstBuffer data type is wrong"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert( - is_known_at_compile_time< - remove_cv_t>>::value && - is_known_at_compile_time>>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp index ccac4b7b44..bbdaa5fa2b 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer_v2.hpp @@ -80,9 +80,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 SrcBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer and SrcData data type are inconsistent"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer and SrcData data type are inconsistent"); // tensor descriptor for src_vector constexpr auto src_vector_tensor_lengths = SrcVectorTensorLengths{}; @@ -248,9 +248,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1 DstBuffer::GetAddressSpace() == AddressSpaceEnum_t::Lds, "wrong!"); - static_assert(is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); // tensor descriptor for dst_vector constexpr auto dst_vector_tensor_lengths = DstVectorTensorLengths{}; @@ -669,24 +669,21 @@ struct ThreadwiseTensorSliceTransfer_v4r1 static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), "wrong! SrcDesc and DstDesc need to known at compile-time"); - static_assert(is_same>, - remove_cv_t>>::value && - is_same>, - remove_cv_t>>::value, - "wrong! SrcBuffer or DstBuffer data type is wrong"); + static_assert( + is_same, remove_cvref_t>::value && + is_same, remove_cvref_t>::value, + "wrong! SrcBuffer or DstBuffer data type is wrong"); static_assert(DstBuffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer"); - static_assert( - is_known_at_compile_time< - remove_cv_t>>::value && - is_known_at_compile_time>>::value, - "wrong! SrcOriginToRefDistance and DstOriginToRefDistance need to be known " - "at compile-time"); + static_assert(is_known_at_compile_time>::value && + is_known_at_compile_time>::value, + "wrong! 
SrcOriginToRefDistance and DstOriginToRefDistance need to be known " + "at compile-time"); // SrcDesc and DstDesc are known at compile-time - constexpr auto src_desc = remove_cv_t>{}; - constexpr auto dst_desc = remove_cv_t>{}; + constexpr auto src_desc = remove_cvref_t{}; + constexpr auto dst_desc = remove_cvref_t{}; // SrcOriginToRefDisttance and DstOriginToRefDistance are known at compile-time constexpr auto src_ref_to_origin_disp_idx = to_multi_index(SrcRefToOriginDisplacement{}); diff --git a/composable_kernel/include/utility/amd_buffer_addressing.hpp b/composable_kernel/include/utility/amd_buffer_addressing.hpp index b7fd4bc409..3df53bda44 100644 --- a/composable_kernel/include/utility/amd_buffer_addressing.hpp +++ b/composable_kernel/include/utility/amd_buffer_addressing.hpp @@ -225,13 +225,49 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w index_t src_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || + (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) + if constexpr(is_same::value) + { + // use fp32 load to mimic fp64 load + if constexpr(N == 1) + { + const float2_t tmp = llvm_amdgcn_raw_buffer_load_fp32x2( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); + } + else if constexpr(N == 2) + { + const float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + return as_type(tmp); + } + else if constexpr(N == 4) + { + const float4_t f32_0 = llvm_amdgcn_raw_buffer_load_fp32x4( + src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); + + const float4_t f32_1 = + llvm_amdgcn_raw_buffer_load_fp32x4(src_wave_buffer_resource, + src_thread_addr_offset, + src_wave_addr_offset + 4 * sizeof(float), + 0); + vector_type tmp; + + tmp.AsType()(Number<0>{}) = as_type(f32_0); + tmp.AsType()(Number<1>{}) = as_type(f32_1); + + return tmp.AsType()(Number<0>{}); + } + } + else if constexpr(is_same::value) { if constexpr(N == 1) { @@ -283,25 +319,11 @@ __device__ typename vector_type::type amd_buffer_load_impl(int32x4_t src_w } else if constexpr(N == 8) { -#if 0 - vector_type tmp; - - tmp.AsType()(Number<0>{}) = llvm_amdgcn_raw_buffer_load_fp16x4( - src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); - - tmp.AsType()(Number<1>{}) = - llvm_amdgcn_raw_buffer_load_fp16x4(src_wave_buffer_resource, - src_thread_addr_offset, - src_wave_addr_offset + 4 * sizeof(half_t), - 0); - - return tmp.AsType()(Number<0>{}); -#else + // use fp32 load to mimic fp16 load float4_t tmp = llvm_amdgcn_raw_buffer_load_fp32x4( src_wave_buffer_resource, src_thread_addr_offset, src_wave_addr_offset, 0); return as_type(tmp); -#endif } } else if constexpr(is_same::value) @@ -433,13 +455,34 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src index_t dst_wave_addr_offset) { static_assert( - (is_same::value && (N == 1 || N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2)) || + (is_same::value && (N == 1 
|| N == 2 || N == 4)) || + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)) || (is_same::value && (N == 1 || N == 2 || N == 4)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)) || - (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8)), + (is_same::value && (N == 1 || N == 2 || N == 4 || N == 8 || N == 16)), "wrong! not implemented"); - if constexpr(is_same::value) + if constexpr(is_same::value) + { + // use fp32 store to mimic fp64 store + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp32x2(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp32x4(as_type(src_thread_data), + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + } + else if constexpr(is_same::value) { if constexpr(N == 1) { @@ -466,6 +509,49 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src 0); } } + else if constexpr(is_same::value) + { + if constexpr(N == 1) + { + llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 2) + { + llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 4) + { + llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + } + else if constexpr(N == 8) + { + vector_type tmp{src_thread_data}; + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset, + 0); + + llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], + dst_wave_buffer_resource, + dst_thread_addr_offset, + dst_wave_addr_offset + 4 * sizeof(half_t), + 0); + } + } else if constexpr(is_same::value) { if constexpr(N == 1) @@ -552,49 +638,6 @@ __device__ void amd_buffer_store_impl(const typename vector_type::type src 0); } } - else if constexpr(is_same::value) - { - if constexpr(N == 1) - { - llvm_amdgcn_raw_buffer_store_fp16(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 2) - { - llvm_amdgcn_raw_buffer_store_fp16x2(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 4) - { - llvm_amdgcn_raw_buffer_store_fp16x4(src_thread_data, - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - } - else if constexpr(N == 8) - { - vector_type tmp{src_thread_data}; - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<0>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset, - 0); - - llvm_amdgcn_raw_buffer_store_fp16x4(tmp.AsType()[Number<1>{}], - dst_wave_buffer_resource, - dst_thread_addr_offset, - dst_wave_addr_offset + 4 * sizeof(half_t), - 0); - } - } } template @@ -720,7 +763,7 @@ __device__ void amd_buffer_atomic_add_impl(const typename vector_type::typ } // buffer_load requires: -// 1) p_src_wave must be in global memory space +// 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. 
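+// Note on the CK_EXPERIMENTAL_USE_BUFFER_*_OOB_CHECK_OFFSET_TRICK paths used
+// by the wrappers below: raw buffer instructions are range-checked against
+// the buffer resource, so an out-of-range voffset is dropped (stores/atomics)
+// or returns zero (loads). A minimal sketch of the trick, with invented
+// variable names rather than the exact ones used in each wrapper:
+//
+//   index_t shift = element_valid ? 0 : 0x7fffffff; // past any real buffer
+//   amd_buffer_store_impl<T, N>(data, rsrc, shift + byte_offset, 0);
+//
+// which avoids a per-thread `if(element_valid)` branch around the access.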
template @@ -754,7 +797,7 @@ amd_buffer_load_invalid_element_return_return_zero(const T* p_src_wave, } // buffer_load requires: -// 1) p_src_wave must be in global memory space +// 1) p_src_wave must point to global memory space // 2) p_src_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template @@ -782,7 +825,7 @@ amd_buffer_load_invalid_element_return_customized_value(const T* p_src_wave, } // buffer_store requires: -// 1) p_dst_wave must be global memory +// 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template @@ -816,7 +859,7 @@ __device__ void amd_buffer_store(const typename vector_type_maker::type::t } // buffer_atomic_add requires: -// 1) p_dst_wave must be global memory +// 1) p_dst_wave must point to global memory // 2) p_dst_wave must be a wavewise pointer. // It is user's responsibility to make sure that is true. template diff --git a/composable_kernel/include/utility/array.hpp b/composable_kernel/include/utility/array.hpp index 7271094d39..911cefd057 100644 --- a/composable_kernel/include/utility/array.hpp +++ b/composable_kernel/include/utility/array.hpp @@ -48,7 +48,7 @@ struct Array template __host__ __device__ constexpr auto make_array(X&& x, Xs&&... xs) { - using data_type = remove_cv_t>; + using data_type = remove_cvref_t; return Array{{std::forward(x), std::forward(xs)...}}; } diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp index 24a2190e84..bfaac8a939 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/composable_kernel/include/utility/data_type.hpp @@ -73,6 +73,13 @@ struct scalar_type> }; // +template <> +struct scalar_type +{ + using type = double; + static constexpr index_t vector_size = 1; +}; + template <> struct scalar_type { @@ -864,6 +871,10 @@ struct vector_type } }; +// fp64 +using double2_t = typename vector_type::type; +using double4_t = typename vector_type::type; + // fp32 using float2_t = typename vector_type::type; using float4_t = typename vector_type::type; diff --git a/composable_kernel/include/utility/dynamic_buffer.hpp b/composable_kernel/include/utility/dynamic_buffer.hpp index a875afd9be..7029bd850f 100644 --- a/composable_kernel/include/utility/dynamic_buffer.hpp +++ b/composable_kernel/include/utility/dynamic_buffer.hpp @@ -39,18 +39,15 @@ struct DynamicBuffer } template >>::type, - typename scalar_type>>::type>::value, - bool>::type = false> + typename enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> __host__ __device__ constexpr auto Get(index_t i, bool is_valid_element) const { // X contains multiple T - constexpr index_t scalar_per_t_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); @@ -67,15 +64,14 @@ struct DynamicBuffer if constexpr(InvalidElementUseNumericalZeroValue) { - return amd_buffer_load_invalid_element_return_return_zero< - remove_cv_t>, - t_per_x>(p_data_, i, is_valid_element, element_space_size_); + return amd_buffer_load_invalid_element_return_return_zero, + t_per_x>( + p_data_, i, is_valid_element, element_space_size_); } else { - return amd_buffer_load_invalid_element_return_customized_value< - remove_cv_t>, - t_per_x>( + return amd_buffer_load_invalid_element_return_customized_value, + t_per_x>( p_data_, i, is_valid_element, element_space_size_, invalid_element_value_); } } @@ -94,18 +90,15 @@ struct DynamicBuffer } template >>::type, - typename scalar_type>>::type>::value, - bool>::type = false> + typename enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> __host__ __device__ void Set(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! X need to be multiple T"); @@ -115,7 +108,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_ADDRESSING constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - amd_buffer_store>, t_per_x>( + amd_buffer_store, t_per_x>( x, p_data_, i, is_valid_element, element_space_size_); #else if(is_valid_element) @@ -136,70 +129,65 @@ struct DynamicBuffer // ISA, so I try to let compiler emit IR "store" which would be lower to // ds_write_b128 // TODO: remove this after compiler fix - if constexpr(is_same>>::type, - int8_t>::value) + if constexpr(is_same>::type, int8_t>::value) { - static_assert( - (is_same>, int8_t>::value && - is_same>, int8_t>::value) || - (is_same>, int8_t>::value && - is_same>, int8x2_t>::value) || - (is_same>, int8_t>::value && - is_same>, int8x4_t>::value) || - (is_same>, int8x4_t>::value && - is_same>, int8x4_t>::value) || - (is_same>, int8x8_t>::value && - is_same>, int8x8_t>::value) || - (is_same>, int8x16_t>::value && - is_same>, int8x16_t>::value), - "wrong! not implemented for this combination, please add " - "implementation"); + static_assert((is_same, int8_t>::value && + is_same, int8_t>::value) || + (is_same, int8_t>::value && + is_same, int8x2_t>::value) || + (is_same, int8_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x4_t>::value && + is_same, int8x4_t>::value) || + (is_same, int8x8_t>::value && + is_same, int8x8_t>::value) || + (is_same, int8x16_t>::value && + is_same, int8x16_t>::value), + "wrong! 
not implemented for this combination, please add " + "implementation"); - if constexpr(is_same>, int8_t>::value && - is_same>, int8_t>::value) + if constexpr(is_same, int8_t>::value && + is_same, int8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, int8_t>::value && - is_same>, int8x2_t>::value) + else if constexpr(is_same, int8_t>::value && + is_same, int8x2_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, int8_t>::value && - is_same>, int8x4_t>::value) + else if constexpr(is_same, int8_t>::value && + is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x4_t>::value && - is_same>, int8x4_t>::value) + else if constexpr(is_same, int8x4_t>::value && + is_same, int8x4_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x8_t>::value && - is_same>, int8x8_t>::value) + else if constexpr(is_same, int8x8_t>::value && + is_same, int8x8_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix *c_style_pointer_cast(&p_data_[i]) = *c_style_pointer_cast(&x); } - else if constexpr(is_same>, - int8x16_t>::value && - is_same>, int8x16_t>::value) + else if constexpr(is_same, int8x16_t>::value && + is_same, int8x16_t>::value) { // HACK: cast pointer of x is bad // TODO: remove this after compiler fix @@ -224,18 +212,15 @@ struct DynamicBuffer } template >>::type, - typename scalar_type>>::type>::value, - bool>::type = false> + typename enable_if>::type, + typename scalar_type>::type>::value, + bool>::type = false> __host__ __device__ void AtomicAdd(index_t i, bool is_valid_element, const X& x) { // X contains multiple T - constexpr index_t scalar_per_t_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_t_vector = scalar_type>::vector_size; - constexpr index_t scalar_per_x_vector = - scalar_type>>::vector_size; + constexpr index_t scalar_per_x_vector = scalar_type>::vector_size; static_assert(scalar_per_x_vector % scalar_per_t_vector == 0, "wrong! 
X need to be multiple T"); @@ -245,7 +230,7 @@ struct DynamicBuffer #if CK_USE_AMD_BUFFER_ADDRESSING constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector; - amd_buffer_atomic_add>, t_per_x>( + amd_buffer_atomic_add, t_per_x>( x, p_data_, i, is_valid_element, element_space_size_); #else if(is_valid_element) @@ -266,9 +251,14 @@ __host__ __device__ constexpr auto make_dynamic_buffer(T* p, ElementSpaceSize el return DynamicBuffer{p, element_space_size}; } -template +template < + AddressSpaceEnum_t BufferAddressSpace, + typename T, + typename ElementSpaceSize, + typename X, + typename enable_if, remove_cvref_t>::value, bool>::type = false> __host__ __device__ constexpr auto -make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, T invalid_element_value) +make_dynamic_buffer(T* p, ElementSpaceSize element_space_size, X invalid_element_value) { return DynamicBuffer{ p, element_space_size, invalid_element_value}; diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index ee96a8b435..70f4d77d87 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -159,7 +159,7 @@ struct Tuple : detail::TupleImpl __host__ __device__ constexpr auto make_tuple(Xs&&... xs) { - return Tuple>...>(std::forward(xs)...); + return Tuple...>(std::forward(xs)...); } } // namespace ck diff --git a/composable_kernel/include/utility/tuple_helper.hpp b/composable_kernel/include/utility/tuple_helper.hpp index 9499a3596c..55a79d2594 100644 --- a/composable_kernel/include/utility/tuple_helper.hpp +++ b/composable_kernel/include/utility/tuple_helper.hpp @@ -14,9 +14,7 @@ struct is_known_at_compile_time> return container_reduce( Tuple{}, [](auto x, bool r) { - return is_known_at_compile_time< - remove_cv_t>>::value & - r; + return is_known_at_compile_time>::value & r; }, true); } diff --git a/composable_kernel/include/utility/type.hpp b/composable_kernel/include/utility/type.hpp index b7902ad496..89a2bdbde6 100644 --- a/composable_kernel/include/utility/type.hpp +++ b/composable_kernel/include/utility/type.hpp @@ -22,6 +22,9 @@ using remove_reference_t = typename std::remove_reference::type; template using remove_cv_t = typename std::remove_cv::type; +template +using remove_cvref_t = remove_cv_t>; + template inline constexpr bool is_pointer_v = std::is_pointer::value; diff --git a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp index c1208ac3cb..71239e0ecc 100644 --- a/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp +++ b/composable_kernel/src/kernel_wrapper/convolution_forward_implicit_gemm_v6r1_dlops_nchw_kcyx_nkhw.cpp @@ -374,13 +374,8 @@ extern "C" __global__ void CGridDesc_GM10_BM0_BM1_GN10_BN0_BN1{}, CGridBlockCluster_BlockId_To_GM10_GN10{})); - const auto desc_tuple = *reinterpret_cast( -#pragma clang diagnostic push -#pragma clang diagnostic ignored "-Wold-style-cast" - // TODO: how to cast? 
- (const void*)p_desc_tuple -#pragma clang diagnostic pop - ); + const auto desc_tuple = + *reinterpret_cast(cast_pointer_to_generic_address_space(p_desc_tuple)); const auto a_grid_desc_gk0_gm0_gm10_gm11_gk1 = desc_tuple[I0]; const auto b_grid_desc_gk0_gn0_gn10_gn11_gk1 = desc_tuple[I1]; From 627d8ef35a6da8ad268b5197e3045ccdfb4ac684 Mon Sep 17 00:00:00 2001 From: ltqin Date: Tue, 31 Aug 2021 11:49:17 +0800 Subject: [PATCH 07/15] Backward weight v4r4r2 with xdlops (#18) * start * modify transformat * modify device convolutiion * modify host * added host conv bwd and wrw * remove bwd, seperate wrw * clean * hacall k to zero * out log * fixed * fixed * change to (out in wei) * input hack * hack to out * format * fix by comments * change wei hacks(wei transform has not merge) * fix program once issue * fix review comment * fix vector load issue * tweak Co-authored-by: ltqin Co-authored-by: Jing Zhang Co-authored-by: Chao Liu --- ...lution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp | 129 ++++++++ host/driver_offline/CMakeLists.txt | 3 + ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 24 +- ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 228 +++++++++++++ ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 85 +++-- ...icit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp | 302 ------------------ ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 32 +- .../src/conv_bwd_driver_offline.cpp | 48 +-- .../src/conv_fwd_driver_offline.cpp | 80 ++--- .../src/conv_wrw_driver_offline.cpp | 281 ++++++++++++++++ ..._conv_wrw.hpp => host_conv_bwd_weight.hpp} | 0 11 files changed, 777 insertions(+), 435 deletions(-) create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp delete mode 100644 host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/src/conv_wrw_driver_offline.cpp rename host/host_tensor/include/{host_conv_wrw.hpp => host_conv_bwd_weight.hpp} (100%) diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..949f044b7d --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp @@ -0,0 +1,129 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 
= Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmk_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // output tensor + const auto out_gemmk_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, 
Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index fec11e99af..8dec70d03f 100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -13,9 +13,12 @@ include_directories(BEFORE set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) +set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) +add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) +target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 7196f3c179..8f49473563 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -208,20 +208,20 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: Gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: Gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: Gemmk1 + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmn - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0
+            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: GemmN
+            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: GemmK1
 // clang-format off
 constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(
diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
new file mode 100644
index 0000000000..e97bc9c1c7
--- /dev/null
+++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp
@@ -0,0 +1,228 @@
+#include
+#include "device.hpp"
+#include "host_tensor.hpp"
+#include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp"
+#include "driver_gemm_xdlops_v2r3.hpp"
+
+template
+void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw(
+    const InLengths& in_n_c_hi_wi_lengths,
+    const WeiLengths& wei_k_c_y_x_lengths,
+    const OutLengths& out_n_k_ho_wo_lengths,
+    const ConvStrides& conv_strides,
+    const ConvDilations& conv_dilations,
+    const InLeftPads& in_left_pads,
+    const InRightPads& in_right_pads,
+    const Tensor& in_n_c_hi_wi,
+    Tensor& wei_k_c_y_x,
+    const Tensor& out_n_k_ho_wo,
+    ck::index_t nrepeat)
+{
+    using namespace ck;
+
+    std::cout << __func__ << std::endl;
+
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+
+    DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace());
+    DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace());
+    DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace());
+
+    in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data());
+    wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data());
+    out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data());
+
+    const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths);
+    const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths);
+    const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths);
+
+#if 1
+    // [M, N, K0, K1] = [128, 128, 4, 8] for fp16
+    constexpr index_t BlockSize = 256;
+
+    constexpr index_t GemmMPerBlock = 128;
+    constexpr index_t GemmNPerBlock = 128;
+    constexpr index_t GemmKPerBlock = 4;
+
+    constexpr index_t GemmMPerWave = 32;
+    constexpr index_t GemmNPerWave = 32;
+    constexpr index_t GemmK1 = 8;
+
+    constexpr index_t MRepeat = 2;
+    constexpr index_t NRepeat = 2;
+
+    using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>;
+    using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>;
+    // using a vector load of 4, so the config's Ho*Wo must be a multiple of 4
+    constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4;
+    constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4;
+
+    using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>;
+    using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>;
+
+    constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1;
+    constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8;
+
+    constexpr index_t GemmCThreadTransferDstScalarPerVector = 1;
+#elif 1
+    // [M, N, K0, K1] = [256, 128, 4, 8] for fp16
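+    // A quick sanity sketch of this config's tile arithmetic (the usual xdlops
+    // tiling invariants, spelled out with this block's literal values since the
+    // constants themselves are only declared below; a wave size of 64 is assumed):
+    //   MWaves = GemmMPerBlock / (MRepeat * GemmMPerWave) = 256 / (4 * 32) = 2
+    //   NWaves = GemmNPerBlock / (NRepeat * GemmNPerWave) = 128 / (2 * 32) = 2
+    //   BlockSize = MWaves * NWaves * 64 = 2 * 2 * 64 = 256
+    static_assert((256 / (4 * 32)) * (128 / (2 * 32)) * 64 == 256,
+                  "MWaves * NWaves * wave_size must equal BlockSize");
+    // Likewise the A copy's per-thread slice (1, 4, 8) times its thread cluster
+    // (4, 64, 1) must cover the whole [GemmK0, GemmM, GemmK1] = [4, 256, 8] tile.
+    static_assert(1 * 4 == 4 && 4 * 64 == 256 && 8 * 1 == 8,
+                  "slice lengths times cluster lengths must tile the A block");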
constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + 
Sequence<0, 0, 1, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false>(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index dc4f5eafb6..d65ecadb4d 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -47,7 +47,7 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); -#if 1 +#if 0 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -74,6 +74,34 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; constexpr index_t 
GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif @@ -92,36 +120,39 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto out_gemmm_gemmn_grid_desc = descs[I2]; // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); + constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 
0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 692751bfb3..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,302 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r3.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r3_xdlops_nhwc_kyxc_nhwk( - const InLengths& in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - constexpr auto I4 = Number<4>{}; - constexpr auto I5 = Number<5>{}; - constexpr auto I6 = Number<6>{}; - constexpr auto I7 = Number<7>{}; - constexpr auto I8 = Number<8>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t 
GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 128; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 256; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 4; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 32; - constexpr index_t GemmNPerWave = 32; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 4; - constexpr 
index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, 
- Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1, 7, 5, 4, 6>, - 6, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat - >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Hi = in_n_hi_wi_c_lengths[I1]; - const auto Wi = in_n_hi_wi_c_lengths[I2]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 5ff8dfb665..52432664de 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -250,22 +250,22 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 + make_tuple(make_tuple(Sequence<0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 67cea94813..4e93ada859 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -41,7 +41,7 @@ int main(int argc, char* argv[]) // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -79,7 +79,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -90,28 +90,28 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) 
/ conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif #if 0 @@ -119,9 +119,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 21acb35732..34d7247f3c 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -19,7 +19,7 @@ #include "device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" #include "device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp" -#define USE_MODE 1 +#define USE_DYNAMIC_MODE 1 #define USE_CONV_FWD_V4R4_NCHW 0 #define USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0 @@ -49,11 +49,11 @@ int main(int argc, char* argv[]) constexpr auto I5 = Number<5>{}; constexpr auto I6 = Number<6>{}; -#if USE_MODE +#if USE_DYNAMIC_MODE // dynamic mode if(argc != 22) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); exit(1); } @@ -91,7 +91,7 @@ int main(int argc, char* argv[]) // static mode if(argc < 7) { - printf("arg1 to 5: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); exit(1); } @@ -102,28 +102,28 @@ int main(int argc, char* argv[]) const bool do_log = std::stoi(argv[5]); const int nrepeat = std::stoi(argv[6]); - constexpr index_t N = 128; - constexpr index_t C = 192; - constexpr index_t Hi = 71; - constexpr index_t Wi = 71; - constexpr index_t K = 256; - constexpr index_t Y = 3; - constexpr index_t X = 3; + constexpr auto N = Number<128>{}; + constexpr auto C = Number<192>{}; + constexpr auto Hi = Number<71>{}; + constexpr auto Wi = Number<71>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; - const index_t conv_stride_h = 2; - const index_t conv_stride_w = 2; - const index_t conv_dilation_h = 1; - const index_t conv_dilation_w = 1; - const index_t in_left_pad_h = 1; - const index_t in_left_pad_w = 1; - const index_t in_right_pad_h = 1; - const index_t in_right_pad_w = 1; + constexpr auto conv_stride_h = I2; + constexpr auto conv_stride_w = I2; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; - const index_t YEff = (Y - 1) * conv_dilation_h + 1; - const index_t XEff = (X - 1) * conv_dilation_w + 1; + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; - const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; - const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / 
conv_stride_w + I1; #endif #if 0 @@ -131,9 +131,9 @@ int main(int argc, char* argv[]) using acc_data_t = float; using out_data_t = float; #elif 1 - using in_data_t = half_t; - using acc_data_t = float; - using out_data_t = half_t; + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; #elif 1 using in_data_t = int8_t; using acc_data_t = int32_t; @@ -228,7 +228,6 @@ int main(int argc, char* argv[]) } auto f_make_for_device_nchw = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); const auto wei_lengths_dev = make_tuple(K, C, Y, X); const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); @@ -236,19 +235,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return make_tuple(in_lengths_dev, wei_lengths_dev, @@ -260,7 +246,6 @@ int main(int argc, char* argv[]) }; auto f_make_for_device_nhwc = [&]() { -#if USE_MODE const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); const auto wei_lengths_dev = make_tuple(K, Y, X, C); const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); @@ -268,19 +253,6 @@ int main(int argc, char* argv[]) const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); -#else - const auto in_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto wei_lengths_dev = make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto out_lengths_dev = - make_tuple(Number{}, Number{}, Number{}, Number{}); - const auto conv_strides_dev = make_tuple(Number{}, Number{}); - const auto conv_dilations_dev = - make_tuple(Number{}, Number{}); - const auto in_left_pads_dev = make_tuple(Number{}, Number{}); - const auto in_right_pads_dev = - make_tuple(Number{}, Number{}); -#endif return make_tuple(in_lengths_dev, wei_lengths_dev, diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp new file mode 100644 index 0000000000..13c73abf30 --- /dev/null +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "conv_common.hpp" +#include "host_conv_bwd_weight.hpp" +#include "device_tensor.hpp" +#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp" + +#define USE_DYNAMIC_MODE 1 +#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1 + +enum ConvBackwardWeightAlgo +{ + V4R4R2XDLNCHW, +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto 
I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + constexpr auto I4 = Number<4>{}; + constexpr auto I5 = Number<5>{}; + constexpr auto I6 = Number<6>{}; + +#if USE_DYNAMIC_MODE + // dynamic mode + if(argc != 22) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t N = std::stoi(argv[7]); + const index_t K = std::stoi(argv[8]); + const index_t C = std::stoi(argv[9]); + const index_t Y = std::stoi(argv[10]); + const index_t X = std::stoi(argv[11]); + const index_t Hi = std::stoi(argv[12]); + const index_t Wi = std::stoi(argv[13]); + + const index_t conv_stride_h = std::stoi(argv[14]); + const index_t conv_stride_w = std::stoi(argv[15]); + const index_t conv_dilation_h = std::stoi(argv[16]); + const index_t conv_dilation_w = std::stoi(argv[17]); + const index_t in_left_pad_h = std::stoi(argv[18]); + const index_t in_left_pad_w = std::stoi(argv[19]); + const index_t in_right_pad_h = std::stoi(argv[20]); + const index_t in_right_pad_w = std::stoi(argv[21]); + + const index_t YEff = (Y - 1) * conv_dilation_h + 1; + const index_t XEff = (X - 1) * conv_dilation_w + 1; + + const index_t Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + 1; + const index_t Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; +#else + // static mode + if(argc < 7) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + exit(1); + } + + const ConvTensorLayout layout = static_cast(std::stoi(argv[1])); + const ConvBackwardWeightAlgo algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + constexpr auto N = Number<128>{}; + constexpr auto C = Number<128>{}; + constexpr auto Hi = Number<14>{}; + constexpr auto Wi = Number<14>{}; + constexpr auto K = Number<256>{}; + constexpr auto Y = Number<3>{}; + constexpr auto X = Number<3>{}; + + constexpr auto conv_stride_h = I1; + constexpr auto conv_stride_w = I1; + constexpr auto conv_dilation_h = I1; + constexpr auto conv_dilation_w = I1; + constexpr auto in_left_pad_h = I1; + constexpr auto in_left_pad_w = I1; + constexpr auto in_right_pad_h = I1; + constexpr auto in_right_pad_w = I1; + + constexpr auto YEff = (Y - I1) * conv_dilation_h + I1; + constexpr auto XEff = (X - I1) * conv_dilation_w + I1; + + constexpr auto Ho = (Hi + in_left_pad_h + in_right_pad_h - YEff) / conv_stride_h + I1; + constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; +#endif + +#if 1 + using in_data_t = float; + using acc_data_t = float; + using out_data_t = float; +#elif 1 + using in_data_t = half_t; + using acc_data_t = float; + using out_data_t = half_t; +#elif 1 + using in_data_t = int8_t; + using acc_data_t = int32_t; + using out_data_t = int8_t; +#endif + + std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); + + if(layout == ConvTensorLayout::NCHW) + { + in_lengths_host[0] = static_cast(N); + in_lengths_host[1] = 
static_cast(C);
+        in_lengths_host[2] = static_cast(Hi);
+        in_lengths_host[3] = static_cast(Wi);
+        wei_lengths_host[0] = static_cast(K);
+        wei_lengths_host[1] = static_cast(C);
+        wei_lengths_host[2] = static_cast(Y);
+        wei_lengths_host[3] = static_cast(X);
+        out_lengths_host[0] = static_cast(N);
+        out_lengths_host[1] = static_cast(K);
+        out_lengths_host[2] = static_cast(Ho);
+        out_lengths_host[3] = static_cast(Wo);
+    }
+    else if(layout == ConvTensorLayout::NHWC)
+    {
+        in_lengths_host[0] = static_cast(N);
+        in_lengths_host[1] = static_cast(Hi);
+        in_lengths_host[2] = static_cast(Wi);
+        in_lengths_host[3] = static_cast(C);
+        wei_lengths_host[0] = static_cast(K);
+        wei_lengths_host[1] = static_cast(Y);
+        wei_lengths_host[2] = static_cast(X);
+        wei_lengths_host[3] = static_cast(C);
+        out_lengths_host[0] = static_cast(N);
+        out_lengths_host[1] = static_cast(Ho);
+        out_lengths_host[2] = static_cast(Wo);
+        out_lengths_host[3] = static_cast(K);
+    }
+    else
+    {
+        throw std::runtime_error("wrong! not implemented");
+    }
+
+    Tensor in(in_lengths_host);
+    Tensor wei_device(wei_lengths_host);
+    Tensor wei_host(wei_lengths_host);
+    Tensor out(out_lengths_host);
+
+    std::cout << "layout: " << layout << std::endl;
+    ostream_HostTensorDescriptor(in.mDesc, std::cout << "in: ");
+    ostream_HostTensorDescriptor(wei_host.mDesc, std::cout << "wei: ");
+    ostream_HostTensorDescriptor(out.mDesc, std::cout << "out: ");
+    print_array("InLeftPads", make_tuple(in_left_pad_h, in_left_pad_w));
+    print_array("InRightPads", make_tuple(in_right_pad_h, in_right_pad_w));
+    print_array("ConvStrides", make_tuple(conv_stride_h, conv_stride_w));
+    print_array("ConvDilations", make_tuple(conv_dilation_h, conv_dilation_w));
+
+    std::size_t num_thread = std::thread::hardware_concurrency();
+
+    switch(init_method)
+    {
+    case 0:
+        // no initialization
+        break;
+    case 1:
+        in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        break;
+    case 2:
+        in.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        break;
+    case 3:
+        in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        out.GenerateTensorValue(GeneratorTensor_1{}, num_thread);
+        break;
+    case 4:
+        in.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        out.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread);
+        break;
+    case 5:
+        in.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread);
+        out.GenerateTensorValue(GeneratorTensor_3{-0.1, 0.1}, num_thread);
+        break;
+    default:
+        in.GenerateTensorValue(GeneratorTensor_2{1, 5}, num_thread);
+
+        auto gen_out = [](auto... is) {
+            return GeneratorTensor_2{1, 5}(is...)
* GeneratorTensor_Checkboard{}(is...); + }; + out.GenerateTensorValue(gen_out, num_thread); + } + + auto f_make_for_device_nchw = [&]() { + const auto in_lengths_dev = make_tuple(N, C, Hi, Wi); + const auto wei_lengths_dev = make_tuple(K, C, Y, X); + const auto out_lengths_dev = make_tuple(N, K, Ho, Wo); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + +#if USE_CONV_WRW_V4R4R2_XDL_NCHW + if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + nrepeat); + } +#endif + + if(do_verification) + { + host_direct_convolution_backward_weights(out, + in, + wei_host, + make_tuple(conv_stride_h, conv_stride_w), + make_tuple(conv_dilation_h, conv_dilation_w), + make_tuple(in_left_pad_h, in_left_pad_w), + make_tuple(in_right_pad_h, in_right_pad_w), + layout); + + check_error(wei_host, wei_device); + + if(do_log) + { + LogRangeAsType(std::cout << "out: ", out.mData, ",") << std::endl; + LogRangeAsType(std::cout << "in : ", in.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_device: ", wei_device.mData, ",") << std::endl; + LogRangeAsType(std::cout << "wei_host : ", wei_host.mData, ",") << std::endl; + } + } +} diff --git a/host/host_tensor/include/host_conv_wrw.hpp b/host/host_tensor/include/host_conv_bwd_weight.hpp similarity index 100% rename from host/host_tensor/include/host_conv_wrw.hpp rename to host/host_tensor/include/host_conv_bwd_weight.hpp From 19613902b58d402c883e033be37ba8a647bcb5a6 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 5 Sep 2021 12:41:28 -0500 Subject: [PATCH 08/15] GEMM driver and kernel (#29) * add gemm driver * tweak * add gemm kernel: mk_kn_mn and km_kn_mn * tweak * add GEMM km_nk_mn * fix comment --- host/driver_offline/CMakeLists.txt | 3 + .../include/device_gemm_xdlops_km_kn_mn.hpp | 219 +++++++++++++ .../include/device_gemm_xdlops_km_nk_mn.hpp | 219 +++++++++++++ .../include/device_gemm_xdlops_mk_kn_mn.hpp | 219 +++++++++++++ .../include/device_gemm_xdlops_mk_nk_mn.hpp | 275 ++++++++++++++++ .../src/gemm_driver_offline.cpp | 294 ++++++++++++++++++ host/host_tensor/include/gemm_common.hpp | 12 + host/host_tensor/include/host_gemm.hpp | 87 ++++++ script/run.sh | 23 +- 9 files changed, 1345 insertions(+), 6 deletions(-) create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp create mode 100644 host/driver_offline/src/gemm_driver_offline.cpp create mode 100644 host/host_tensor/include/gemm_common.hpp create mode 100644 host/host_tensor/include/host_gemm.hpp diff --git a/host/driver_offline/CMakeLists.txt b/host/driver_offline/CMakeLists.txt index 8dec70d03f..a3b3613293 
100644 --- a/host/driver_offline/CMakeLists.txt +++ b/host/driver_offline/CMakeLists.txt @@ -14,11 +14,14 @@ include_directories(BEFORE set(CONV_FWD_DRIVER_OFFLINE_SOURCE src/conv_fwd_driver_offline.cpp) set(CONV_BWD_DRIVER_OFFLINE_SOURCE src/conv_bwd_driver_offline.cpp) set(CONV_WRW_DRIVER_OFFLINE_SOURCE src/conv_wrw_driver_offline.cpp) +set(GEMM_DRIVER_OFFLINE_SOURCE src/gemm_driver_offline.cpp) add_executable(conv_fwd_driver_offline ${CONV_FWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_bwd_driver_offline ${CONV_BWD_DRIVER_OFFLINE_SOURCE}) add_executable(conv_wrw_driver_offline ${CONV_WRW_DRIVER_OFFLINE_SOURCE}) +add_executable(gemm_driver_offline ${GEMM_DRIVER_OFFLINE_SOURCE}) target_link_libraries(conv_fwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_bwd_driver_offline PRIVATE host_tensor) target_link_libraries(conv_wrw_driver_offline PRIVATE host_tensor) +target_link_libraries(gemm_driver_offline PRIVATE host_tensor) diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp new file mode 100644 index 0000000000..d9169649e6 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, + const BDesc& b_k_n_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_k_m, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using 
ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_k_n_grid_desc.GetLength(I1); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k_m_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_k_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + 
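+                               // C is written out through the 8-d xdlops view
+                               // [M0, N0, M1, N1, M2, M3, M4, N2] labeled in the step
+                               // hacks above. Reading the surrounding parameter names,
+                               // the Sequence on the next line is the C thread copy's
+                               // access order over those dims, and the bare index after
+                               // it picks the vector dimension (7, i.e. N2), which
+                               // CThreadTransferDstScalarPerVector then applies to.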
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp new file mode 100644 index 0000000000..90e258d581 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, + const BDesc& b_n_k_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_k_m, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t 
NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_k_m_grid_desc.GetLength(I0); + const auto M = a_k_m_grid_desc.GetLength(I1); + const auto N = b_n_k_grid_desc.GetLength(I0); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k_m_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(M)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_n_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + 
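// [editor's note, not part of the original patch] In this KM x NK driver the
// two operands vectorize on different axes. A is stored [K, M], so M is the
// contiguous dimension: the A copy above uses src access order
// Sequence<0, 2, 1> with src vector dim 1 and ABlockTransferSrcScalarPerVector_M.
// B is stored [N, K], so K is contiguous and the B copy vectorizes along K1
// (src vector dim 2). A hedged sketch of the per-thread global reads this
// implies (raw pointers and index math are illustrative, not the CK API):
//
//   for(int v = 0; v < 4; ++v)            // ScalarPerVector_M = 4, along M
//       a_vec[v] = p_a[k * M + (m + v)];
//   for(int v = 0; v < 8; ++v)            // ScalarPerVector_K1 = 8, along K
//       b_vec[v] = p_b[n * K + (k0 * 8 + v)];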
BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp new file mode 100644 index 0000000000..ab235d97e7 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -0,0 +1,219 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, + const BDesc& b_k_n_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_m_k, + const Tensor& b_k_n, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr 
index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_m_k_grid_desc.GetLength(I1); + const auto M = a_m_k_grid_desc.GetLength(I0); + const auto N = b_k_n_grid_desc.GetLength(I1); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_m_k_grid_desc, + make_tuple(make_pass_through_transform(M), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_k_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), + make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + 
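// [editor's note, not part of the original patch] Sanity check on the copy
// decomposition for the active fp16 config above, written out as a hedged
// sketch: each thread owns one slice and threads are clustered so that
//   slice Sequence<1, 2, 8> * cluster Sequence<4, 64, 1>
//     = (1*4, 2*64, 8*1) = (4, 128, 8) = (KPerBlock, NPerBlock, K1)
// per block per outer iteration, using 4 * 64 * 1 = 256 = BlockSize threads.
// Here B is stored [K, N], so N is contiguous and the B copy below uses src
// access order Sequence<0, 2, 1> with BBlockTransferSrcScalarPerVector_N = 2.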
BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp new file mode 100644 index 0000000000..c68442d127 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp @@ -0,0 +1,275 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, + const BDesc& b_n_k_grid_desc, + const CDesc& c_m_n_grid_desc, + const Tensor& a_m_k, + const Tensor& b_n_k, + Tensor& c_m_n, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_m_n_device_buf.ToDevice(c_m_n.mData.data()); + +#if 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t 
MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#endif + + const auto K = a_m_k_grid_desc.GetLength(I1); + const auto M = a_m_k_grid_desc.GetLength(I0); + const auto N = b_n_k_grid_desc.GetLength(I0); + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_m_k_grid_desc, + make_tuple(make_pass_through_transform(M), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + const auto b_k0_n_k1_grid_desc = + transform_tensor_descriptor(b_n_k_grid_desc, + make_tuple(make_pass_through_transform(N), + make_unmerge_transform(make_tuple(K0, K1Number))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, 
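// [editor's note, not part of the original patch] these transforms only
// re-index; no data is moved. Viewing row-major b[n, k] as b[k0, n, k1] with
// k = k0 * K1 + k1 resolves to the same element; as a worked example:
//   offset(k0, n, k1) = n * K + k0 * K1 + k1
//   with K = 4096 and K1 = 8 (so K0 = 512), (k0, n, k1) = (3, 0, 5)
//   addresses b(0, 3 * 8 + 5) = b(0, 29).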
Sequence<0, 2>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0>{}, // 1+: N + Sequence<0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0>{}, // 1-: N + Sequence<0, 0, 0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + 7, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_m_n_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_m_n_device_buf.FromDevice(c_m_n.mData.data()); +} diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp new file mode 100644 index 0000000000..42c69ff6a2 --- /dev/null +++ 
b/host/driver_offline/src/gemm_driver_offline.cpp @@ -0,0 +1,294 @@ +#include +#include +#include +#include +#include +#include +#include "config.hpp" +#include "print.hpp" +#include "device.hpp" +#include "host_tensor.hpp" +#include "host_tensor_generator.hpp" +#include "gemm_common.hpp" +#include "host_gemm.hpp" +#include "device_tensor.hpp" +#include "device_gemm_xdlops_mk_kn_mn.hpp" +#include "device_gemm_xdlops_mk_nk_mn.hpp" +#include "device_gemm_xdlops_km_kn_mn.hpp" +#include "device_gemm_xdlops_km_nk_mn.hpp" + +#define USE_GEMM_XDL_MK_KN_MN 1 +#define USE_GEMM_XDL_MK_NK_MN 1 +#define USE_GEMM_XDL_KM_KN_MN 1 +#define USE_GEMM_XDL_KM_NK_MN 1 + +enum GemmAlgo +{ + Xdl_MK_KN_MN, // 0 + Xdl_MK_NK_MN, // 1 + Xdl_KM_KN_MN, // 2 + Xdl_KM_NK_MN, // 3 +}; + +int main(int argc, char* argv[]) +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + + // dynamic mode + if(argc != 10) + { + printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); + printf("rest: M, N, K\n"); + exit(1); + } + + const auto layout = static_cast(std::stoi(argv[1])); + const auto algo = static_cast(std::stoi(argv[2])); + const bool do_verification = std::stoi(argv[3]); + const int init_method = std::stoi(argv[4]); + const bool do_log = std::stoi(argv[5]); + const int nrepeat = std::stoi(argv[6]); + + const index_t M = std::stoi(argv[7]); + const index_t N = std::stoi(argv[8]); + const index_t K = std::stoi(argv[9]); + +#if 0 + using ab_data_t = float; + using acc_data_t = float; + using c_data_t = float; +#elif 1 + using ab_data_t = half_t; + using acc_data_t = float; + using c_data_t = half_t; +#elif 1 + using ab_data_t = int8_t; + using acc_data_t = int32_t; + using c_data_t = int8_t; +#endif + + std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); + std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); + + if(layout == GemmMatrixLayout::MK_KN_MN) + { + a_lengths_host[0] = static_cast(M); + a_lengths_host[1] = static_cast(K); + a_strides_host[0] = static_cast(K); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + a_lengths_host[0] = static_cast(M); + a_lengths_host[1] = static_cast(K); + a_strides_host[0] = static_cast(K); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(N); + b_lengths_host[1] = static_cast(K); + b_strides_host[0] = static_cast(K); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + a_lengths_host[0] = static_cast(K); + a_lengths_host[1] = static_cast(M); + a_strides_host[0] = static_cast(M); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + 
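// [editor's note, not part of the original patch] the enum reads as
// "<A layout>_<B layout>_<C layout>": KM_NK_MN means A is stored [K, M]
// (A transposed relative to its M x K role) and B is stored [N, K]
// (B transposed), with C always row-major [M, N]. Each branch just spells
// out the row-major lengths/strides for that storage; for this branch:
//   A: lengths (K, M), strides (M, 1)
//   B: lengths (N, K), strides (K, 1)
//   C: lengths (M, N), strides (N, 1)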
a_lengths_host[0] = static_cast(K); + a_lengths_host[1] = static_cast(M); + a_strides_host[0] = static_cast(M); + a_strides_host[1] = static_cast(1); + + b_lengths_host[0] = static_cast(N); + b_lengths_host[1] = static_cast(K); + b_strides_host[0] = static_cast(K); + b_strides_host[1] = static_cast(1); + + c_lengths_host[0] = static_cast(M); + c_lengths_host[1] = static_cast(N); + c_strides_host[0] = static_cast(N); + c_strides_host[1] = static_cast(1); + } + else + { + throw std::runtime_error("wrong! not implemented"); + } + + Tensor a(a_lengths_host, a_strides_host); + Tensor b(b_lengths_host, b_strides_host); + Tensor c_host(c_lengths_host, c_strides_host); + Tensor c_device(c_lengths_host, c_strides_host); + + std::cout << "layout: " << layout << std::endl; + ostream_HostTensorDescriptor(a.mDesc, std::cout << "a: "); + ostream_HostTensorDescriptor(b.mDesc, std::cout << "b: "); + ostream_HostTensorDescriptor(c_host.mDesc, std::cout << "c: "); + + std::size_t num_thread = std::thread::hardware_concurrency(); + + switch(init_method) + { + case 0: + // no initialization + break; + case 1: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 2: + a.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + case 3: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_1{}, num_thread); + break; + case 4: + a.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + b.GenerateTensorValue(GeneratorTensor_2{-5, 5}, num_thread); + break; + default: + a.GenerateTensorValue(GeneratorTensor_3{0.0, 1.0}, num_thread); + b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); + } + + auto f_make_for_device_mk_kn_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_mk_nk_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_km_kn_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + + auto f_make_for_device_km_nk_mn = [&]() { + const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); + const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); + const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); + + return make_tuple(a_desc, b_desc, c_desc); + }; + +#if USE_GEMM_XDL_MK_KN_MN + if(algo == GemmAlgo::Xdl_MK_KN_MN) + { + if(layout != GemmMatrixLayout::MK_KN_MN) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto descs = f_make_for_device_mk_kn_mn(); + + device_gemm_xdlops_mk_kn_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_NK_MN + if(algo == GemmAlgo::Xdl_MK_NK_MN) + { + if(layout != GemmMatrixLayout::MK_NK_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_mk_nk_mn(); + + device_gemm_xdlops_mk_nk_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_KN_MN + if(algo == GemmAlgo::Xdl_KM_KN_MN) + { + if(layout != GemmMatrixLayout::KM_KN_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_km_kn_mn(); + + device_gemm_xdlops_km_kn_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_NK_MN + if(algo == GemmAlgo::Xdl_KM_NK_MN) + { + if(layout != GemmMatrixLayout::KM_NK_MN) + { + throw std::runtime_error("wrong! layout"); + } + + const auto descs = f_make_for_device_km_nk_mn(); + + device_gemm_xdlops_km_nk_mn( + descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + } +#endif + + if(do_verification) + { + host_gemm(a, b, c_host, layout); + + check_error(c_host, c_device); + + if(do_log) + { + LogRangeAsType(std::cout << "a : ", a.mData, ",") << std::endl; + LogRangeAsType(std::cout << "b: ", b.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_host : ", c_host.mData, ",") << std::endl; + LogRangeAsType(std::cout << "c_device: ", c_device.mData, ",") << std::endl; + } + } +} diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp new file mode 100644 index 0000000000..f0f35a78b9 --- /dev/null +++ b/host/host_tensor/include/gemm_common.hpp @@ -0,0 +1,12 @@ +#ifndef GEMM_COMMON_HPP +#define GEMM_COMMON_HPP + +enum GemmMatrixLayout +{ + MK_KN_MN, // 0 + MK_NK_MN, // 1 + KM_KN_MN, // 2 + KM_NK_MN, // 3 +}; + +#endif diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp new file mode 100644 index 0000000000..97cf245054 --- /dev/null +++ b/host/host_tensor/include/host_gemm.hpp @@ -0,0 +1,87 @@ +#pragma once +#include "host_tensor.hpp" +#include "gemm_common.hpp" + +template +void host_gemm(const Tensor& a, + const Tensor& b, + Tensor& c, + const GemmMatrixLayout layout) +{ + if(layout == GemmMatrixLayout::MK_KN_MN) + { + auto f_mk_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_MN) + { + auto f_mk_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_MN) + { + auto f_km_kn_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + 
} + else if(layout == GemmMatrixLayout::KM_NK_MN) + { + auto f_km_nk_mn = [&](auto m, auto n) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(m, n) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else + { + throw std::runtime_error("wrong! not supported layout"); + } +} diff --git a/script/run.sh b/script/run.sh index ecb5c85d81..3b383fcf3a 100755 --- a/script/run.sh +++ b/script/run.sh @@ -12,13 +12,16 @@ #export OLC_DEBUG_HIP_DUMP=1 #export OLC_DEBUG_SAVE_TEMP_DIR=1 - make -j conv_fwd_driver_offline - make -j conv_bwd_driver_offline - make -j conv_fwd_driver_online - #rm -rf /root/_hip_binary_kernels_/ #rm -rf /tmp/olCompile* +#make -j conv_fwd_driver_offline +#make -j conv_bwd_driver_offline +#make -j conv_wrw_driver_offline +#make -j conv_fwd_driver_online + + make -j gemm_driver_offline + LAYOUT=$1 ALGO=$2 VERIFY=$3 @@ -30,7 +33,7 @@ REPEAT=$6 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 - ./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 #./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 @@ -44,4 +47,12 @@ REPEAT=$6 #./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 + +#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 + +################################################ layout algo verify init log repeat M___ N___ K___ +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 + ./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 +#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 From f3acd2510b21b2a589b4bf38c328d8232bc96812 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Sun, 5 Sep 2021 12:57:57 -0500 Subject: [PATCH 09/15] Add a version of Merge transform that uses integer division and mod (#25) * add Merge_v3_division_mod * refactor --- .../multi_index_transform.hpp | 123 ++++++++++++++ .../multi_index_transform_helper.hpp | 30 +++-- 2 files changed, 145 insertions(+), 8 deletions(-) diff --git a/composable_kernel/include/tensor_description/multi_index_transform.hpp 
b/composable_kernel/include/tensor_description/multi_index_transform.hpp index 42a5a875b7..1a25e99f3b 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform.hpp @@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division } }; +// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to +// be used for low_lengths that are known at compile time and are power of 2, otherwise performance +// will be very bad +template +struct Merge_v3_division_mod +{ + static constexpr index_t NDimLow = LowLengths::Size(); + + using LowerIndex = MultiIndex; + using UpperIndex = MultiIndex<1>; + + using LowLengthsScan = + decltype(container_reverse_exclusive_scan(LowLengths{}, math::multiplies{}, Number<1>{})); + + using UpLengths = + decltype(make_tuple(container_reduce(LowLengths{}, math::multiplies{}, Number<1>{}))); + + LowLengths low_lengths_; + LowLengthsScan low_lengths_scan_; + UpLengths up_lengths_; + + __host__ __device__ constexpr Merge_v3_division_mod() = default; + + __host__ __device__ constexpr Merge_v3_division_mod(const LowLengths& low_lengths) + : low_lengths_{low_lengths}, + low_lengths_scan_{ + container_reverse_exclusive_scan(low_lengths, math::multiplies{}, Number<1>{})}, + up_lengths_{make_tuple(container_reduce(low_lengths, math::multiplies{}, Number<1>{}))} + { + static_assert(LowerIndex::Size() == NDimLow, "wrong!"); + } + + __host__ __device__ static constexpr index_t GetNumOfLowerDimension() { return NDimLow; } + + __host__ __device__ static constexpr index_t GetNumOfUpperDimension() { return 1; } + + __host__ __device__ constexpr const auto& GetUpperLengths() const { return up_lengths_; } + + template + __host__ __device__ constexpr void CalculateLowerIndex(LowIdx& idx_low, + const UpIdx& idx_up) const + { + static_assert(LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! inconsistent # of dimension"); + + index_t tmp = idx_up[Number<0>{}]; + + // division and mod + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + idx_low(i) = tmp / this->low_lengths_scan_[i]; + tmp %= this->low_lengths_scan_[i]; + }); + + idx_low(Number{}) = tmp; + } + + template + __host__ __device__ void UpdateLowerIndex(LowIdxDiff& idx_diff_low, + const UpIdxDiff&, + LowIdx& idx_low, + const UpIdx& idx_up_new, + Number) const + { + static_assert(LowIdxDiff::Size() == NDimLow && UpIdxDiff::Size() == 1 && + LowIdx::Size() == NDimLow && UpIdx::Size() == 1, + "wrong! 
inconsistent # of dimension"); + + constexpr auto I0 = Number<0>{}; + constexpr auto INm1 = Number{}; + + index_t tmp = idx_up_new[I0]; + + static_for<0, NDimLow - 1, 1>{}([&](auto i) { + const index_t tmp2 = idx_low[i]; + idx_low(i) = tmp / this->low_lengths_scan_[i]; + idx_diff_low(i) = idx_low[i] - tmp2; + tmp %= this->low_lengths_scan_[i]; + }); + + const index_t tmp2 = idx_low[INm1]; + idx_low(INm1) = tmp; + idx_diff_low(INm1) = idx_low[INm1] - tmp2; + } + + __host__ __device__ static constexpr bool IsLinearTransform() { return false; } + + __host__ __device__ static constexpr bool IsValidUpperIndexAlwaysMappedToValidLowerIndex() + { + return true; + } + + __host__ __device__ static constexpr bool IsKnownAtCompileTime() + { + return is_known_at_compile_time::value && + is_known_at_compile_time::value && + is_known_at_compile_time::value; + } + + template + __host__ __device__ static constexpr bool + IsValidUpperIndexMappedToValidLowerIndex(const UpIdx& /* idx_up */) + { + return true; + } + + __host__ __device__ void Print() const + { + printf("{"); + printf("Merge_v3_division_mod, "); + printf("low_lengths_ "); + print_multi_index(low_lengths_); + printf("low_lengths_scan_ "); + print_multi_index(low_lengths_scan_); + printf("up_lengths_ "); + print_multi_index(up_lengths_); + printf("}"); + } +}; + template struct UnMerge { diff --git a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp index 6d4e01888b..32acceb608 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp @@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng template __host__ __device__ constexpr auto make_merge_transform(const LowLengths& low_lengths) { -#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION +#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION + return make_merge_transform_v2_magic_division(low_lengths); +#else + return make_merge_transform_v1_carry_check(low_lengths); +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v1_carry_check(const LowLengths& low_lengths) +{ return Merge_v1_carry_check{low_lengths}; -#else -#if 1 - return Merge_v2_magic_division{low_lengths}; -#else - return Merge_v2r2_magic_division{low_lengths}; -#endif -#endif } template __host__ __device__ constexpr auto make_merge_transform_v2_magic_division(const LowLengths& low_lengths) { +#if 1 return Merge_v2_magic_division{low_lengths}; +#else + return Merge_v2r2_magic_division{low_lengths}; +#endif +} + +template +__host__ __device__ constexpr auto +make_merge_transform_v3_division_mod(const LowLengths& low_lengths) +{ + return Merge_v3_division_mod{low_lengths}; } template From 846f462bd414b1cce9114de673f1ed9b360c0ce5 Mon Sep 17 00:00:00 2001 From: zjing14 Date: Wed, 6 Oct 2021 10:13:52 -0500 Subject: [PATCH 10/15] Add VectorType support into StaticBuffer (#27) * init StaticBufferV2 * clean * adopt old output stage for staticBufferV2 * clean * remove hack * clean * clean * clean code * move c_buffer alloc into blockwise gemm * add adaptors for m/n_thread_data_on_grid * adjust blockwise_gemm_xdlops * reorder ops in GEMM hot loop Co-authored-by: Chao Liu --- .../blockwise_gemm_xdlops.hpp | 82 +++--- .../gridwise_gemm_xdlops_v2r3.hpp | 269 ++++-------------- .../include/tensor_operation/xdlops_gemm.hpp | 100 ++----- 
.../include/utility/amd_xdlops.hpp | 257 +++++------------ .../include/utility/static_buffer.hpp | 92 ++++++ 5 files changed, 282 insertions(+), 518 deletions(-) diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index a8236737df..36c6783204 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -10,6 +10,7 @@ namespace ck { template {}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + StaticBufferV2, MRepeat * NRepeat, true> + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + __device__ static auto GetWaveIdx() { const index_t thread_id = get_thread_local_1d_id(); @@ -162,7 +167,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { return transform_tensor_descriptor( AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), + make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -174,7 +179,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { return transform_tensor_descriptor( BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), + make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -195,48 +200,43 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - vector_type a_thread_vec; - - vector_type b_thread_vec; - - static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { // read A a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, - make_tuple(k0, I0, I0, I0, I0), + make_tuple(I0, m0, I0, I0, I0), a_block_buf, a_thread_desc_, make_tuple(I0, I0, I0, I0, I0), a_thread_buf); - // read B - b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, - make_tuple(k0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - b_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, + make_tuple(I0, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + b_thread_buf); - using mfma_input_type = typename vector_type::type; + static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, K1, 1>{}([&](auto i) { a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - }); - - static_for<0, K1, 1>{}([&](auto i) { + [Number{}]; b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; + [Number{}]; }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + using mfma_input_type = + typename vector_type::type; - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0)); + + xdlops_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVector(Number{})); 
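// [editor's note, not part of the original patch] the reordered hot loop is
// m0 (MRepeat) -> n0 (NRepeat) -> k0, so one A fragment is reused against all
// NRepeat B fragments, and each (m0, n0) pair accumulates into the
// vector-typed StaticBufferV2 slot picked by c_offset. Per k0 step the
// operand vectors carry K1 scalars; as a rough count for fp16 with K1 = 8 and
// a k_per_blk = 4 instruction such as mfma_f32_32x32x8f16, one call into
// xdlops_gemm.Run issues K1 / k_per_blk = 2 mfma instructions.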
}); }); }); @@ -244,35 +244,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 private: // A[K, M] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, I1, I1, Number{})); + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); // B[K, N] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); - static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence, Sequence<0, 1, 2, 3, 4>, 4, K1, - 1>; + K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence, Sequence<0, 1, 2, 3, 4>, 4, K1, - 1>; + K1>; AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 3e4d74e9d8..c6f491dc47 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -142,6 +142,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -220,6 +221,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - constexpr auto c_mr_nr_blk_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor(); - constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize(); - - StaticBuffer, - c_mr_nr_blk_desc.GetElementSpaceSize(), - true> - c_thread_buf; + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -460,9 +452,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); // calculate origin of thread output tensor on global memory // blockwise GEMM c 
matrix starting index @@ -477,224 +478,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3, + Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, true>{ + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - make_multi_index(0, - 0, - 0, - 0, - m_thread_data_on_grid / (M3 * M4), - m_thread_data_on_grid % (M3 * M4) / M4, - m_thread_data_on_grid % M4, - n_thread_data_on_grid)}; + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2])}; - auto init_copy = [&](auto c_thread_idx_) { - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - - return c_thread_idx_; - }; - - auto mrepeat_plus_copy = [&](auto c_thread_idx_) { - constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - mrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - auto nrepeat_plus_copy = [&](auto c_thread_idx_) { - constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - nrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - auto mrepeat_minus_copy = [&](auto c_thread_idx_) { - constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - mrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - 
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - auto nrepeat_minus_copy = [&](auto c_thread_idx_) { - constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - nrepeat_step_minus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or - (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or - (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or - (MRepeat == 1 && NRepeat == 1), - "wrong"); - - if constexpr(MRepeat == 4 && NRepeat == 4) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - nrepeat_plus_copy(make_tuple(I0, I3)); - mrepeat_plus_copy(make_tuple(I1, I3)); - nrepeat_minus_copy(make_tuple(I1, I2)); - nrepeat_minus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - nrepeat_plus_copy(make_tuple(I2, I1)); - nrepeat_plus_copy(make_tuple(I2, I2)); - nrepeat_plus_copy(make_tuple(I2, I3)); - mrepeat_plus_copy(make_tuple(I3, I3)); - nrepeat_minus_copy(make_tuple(I3, I2)); - nrepeat_minus_copy(make_tuple(I3, I1)); - nrepeat_minus_copy(make_tuple(I3, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - mrepeat_plus_copy(make_tuple(I3, I0)); - nrepeat_plus_copy(make_tuple(I3, I1)); - mrepeat_minus_copy(make_tuple(I2, I1)); - mrepeat_minus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - mrepeat_plus_copy(make_tuple(I1, I2)); - mrepeat_plus_copy(make_tuple(I2, I2)); - mrepeat_plus_copy(make_tuple(I3, I2)); - nrepeat_plus_copy(make_tuple(I3, I3)); - mrepeat_minus_copy(make_tuple(I2, I3)); - mrepeat_minus_copy(make_tuple(I1, I3)); - mrepeat_minus_copy(make_tuple(I0, I3)); - } - } - else if constexpr(MRepeat == 4 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - mrepeat_plus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - nrepeat_plus_copy(make_tuple(I2, I1)); - mrepeat_plus_copy(make_tuple(I3, I1)); - nrepeat_minus_copy(make_tuple(I3, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - mrepeat_plus_copy(make_tuple(I3, I0)); - nrepeat_plus_copy(make_tuple(I3, I1)); - mrepeat_minus_copy(make_tuple(I2, I1)); - mrepeat_minus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 4) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - nrepeat_plus_copy(make_tuple(I0, I3)); - mrepeat_plus_copy(make_tuple(I1, I3)); - 
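// [editor's note, not part of the original patch] the deleted lambdas above
// walked the MRepeat x NRepeat tile in a serpentine order so that every
// MoveDstSliceWindow step was a single +/-1 move in one repeat dimension,
// which forced the hand-unrolled (MRepeat, NRepeat) case table below. With
// the C thread buffer now owned by the blockwise GEMM as one StaticBufferV2,
// a single ThreadwiseTensorSliceTransfer_v1r3::Run over the whole
// (M0, N0, ..., N2) thread slice replaces all of it.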
nrepeat_minus_copy(make_tuple(I1, I2)); - nrepeat_minus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - nrepeat_plus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - mrepeat_plus_copy(make_tuple(I1, I2)); - nrepeat_plus_copy(make_tuple(I1, I3)); - mrepeat_minus_copy(make_tuple(I0, I3)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - mrepeat_plus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - nrepeat_plus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 1) - { - init_copy(make_tuple(I0, I0)); - mrepeat_plus_copy(make_tuple(I1, I0)); - } - else if constexpr(MRepeat == 1 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - nrepeat_plus_copy(make_tuple(I0, I1)); - } - else if constexpr(MRepeat == 1 && NRepeat == 1) - { - init_copy(make_tuple(I0, I0)); - } + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); } } }; // namespace ck diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index f945b0fdf5..10633f8f32 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -44,15 +44,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); } }; @@ -71,15 +66,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); + intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); } }; @@ -98,15 +88,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); + intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); } }; @@ -125,15 +110,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); + intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); } }; @@ -153,15 +133,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); + intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); } }; @@ -180,15 +155,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const 
FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); } }; @@ -207,15 +177,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); } }; @@ -234,15 +199,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); } }; @@ -261,15 +221,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); } }; @@ -288,15 +243,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); } }; @@ -732,7 +682,7 @@ struct XdlopsGemm return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } - template + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || @@ -740,8 +690,7 @@ struct XdlopsGemm "base base_type must be float, half, ushort!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { - mfma_instr.template run( - p_a_wave[k], p_b_wave[k], p_c_thread); + mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); }); } @@ -819,8 +768,9 @@ struct XdlopsGemm static constexpr auto mfma_instr = mfma.selected_mfma; - static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto KPerThread = mfma.GetKPerThread(); + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetKPerThread(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() { diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp index da74fe1d48..083e47fbf1 100644 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); -template +template struct intrin_mfma_f32_32x32x1f32; -template -struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x1f32<64, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template 
AsType()[Number<0>{}], - 1, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; -template -struct intrin_mfma_f32_32x32x1f32<32, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x1f32<32, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; -template +template struct intrin_mfma_f32_32x32x2f32; -template -struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> +template <> +struct intrin_mfma_f32_32x32x2f32<32, 32> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x4f32; -template -struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> +template <> +struct intrin_mfma_f32_16x16x4f32<16, 16> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x1f32; -template -struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> +template <> +struct intrin_mfma_f32_16x16x1f32<16, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 2, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; -template +template struct intrin_mfma_f32_4x4x1f32; -template -struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x1f32<4, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; -template -struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x1f32<8, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - 
reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; -template +template struct intrin_mfma_f32_32x32x4f16; -template -struct intrin_mfma_f32_32x32x4f16<64, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x4f16<64, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; -template -struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x4f16<32, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; -template +template struct intrin_mfma_f32_32x32x8f16; -template -struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> +template <> +struct intrin_mfma_f32_32x32x8f16<32, 32> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x16f16; -template -struct intrin_mfma_f32_16x16x16f16<16, 16, COffset> +template <> +struct intrin_mfma_f32_16x16x16f16<16, 16> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x4f16; -template -struct intrin_mfma_f32_16x16x4f16<16, 64, COffset> +template <> +struct intrin_mfma_f32_16x16x4f16<16, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - 
llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 2, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; -template +template struct intrin_mfma_f32_4x4x4f16; -template -struct intrin_mfma_f32_4x4x4f16<4, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x4f16<4, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; -template -struct intrin_mfma_f32_4x4x4f16<8, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x4f16<8, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; @@ -448,7 +340,6 @@ template __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec16_1_t::VecType reg_c); - template <> __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, const ushort2_t* reg_b, diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp index cd67b8a0be..9615d10c59 100644 --- a/composable_kernel/include/utility/static_buffer.hpp +++ b/composable_kernel/include/utility/static_buffer.hpp @@ -55,6 +55,98 @@ struct StaticBuffer : public StaticallyIndexedArray __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } }; +template +struct StaticBufferV2 : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + using VecBaseType = typename T::d1_t; + + __host__ __device__ static constexpr index_t GetVectorSize() + { + return sizeof(typename T::type) / sizeof(VecBaseType); + } + + static constexpr index_t vector_size = GetVectorSize(); + + VecBaseType invalid_element_value_ = VecBaseType{0}; + + T invalid_vec_value_ = T{0}; + + __host__ __device__ constexpr StaticBufferV2() : base{} {} + + __host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value) + : base{}, + invalid_vec_value_{invalid_element_value}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() + { + return BufferAddressSpace; + } + + template + __host__ __device__ constexpr auto& GetVector(Number vec_id) + { + return this->At(vec_id); + } + + template + __host__ __device__ constexpr const auto& GetVector(Number vec_id) const + { + return this->At(vec_id); + } + + template + __host__ __device__ constexpr auto& 
GetElement(Number i, bool) + { + constexpr auto vec_id = Number{}; + constexpr auto vec_off = Number{}; + + return this->At(vec_id).template AsType()(vec_off); + } + + template + __host__ __device__ constexpr auto GetElement(Number i, bool is_valid_element) const + { + constexpr auto vec_id = Number{}; + constexpr auto vec_off = Number{}; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return is_valid_element ? this->At(vec_id).template AsType()[vec_off] + : VecBaseType{0}; + } + else + { + return is_valid_element ? this->At(vec_id).template AsType()[vec_off] + : invalid_element_value_; + } + } + + template + __host__ __device__ constexpr auto operator[](Number i) const + { + return GetElement(i, true); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return GetElement(i, true); + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } +}; + template __host__ __device__ constexpr auto make_static_buffer(Number) { From b3e8d57d51300b88b591900621f71b6a1b3a7acc Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Wed, 6 Oct 2021 11:12:36 -0500 Subject: [PATCH 11/15] Tweak GEMM kernel (#38) * add parameters * tweak gemm * tweak * update conv * update script * adding bwd 1x1 * update script * adding 1x1 bwd * debugging bwd 1x1 failure * update script * update script * test * test v100 * clean up --- ...lution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp | 103 ++++- .../multi_index_transform_helper.hpp | 2 +- .../gridwise_gemm_xdlops_v2r3.hpp | 173 ++++++-- composable_kernel/include/utility/config.hpp | 4 +- host/driver_offline/include/debug.hpp | 13 + ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp | 44 +- ...icit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp | 271 ++++++++---- ..._gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp | 389 ++++++++++++++++++ ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp | 29 +- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 76 +++- .../include/device_gemm_xdlops_km_kn_mn.hpp | 332 +++++++++++++-- .../include/device_gemm_xdlops_km_kn_nm.hpp | 263 ++++++++++++ .../include/device_gemm_xdlops_km_nk_mn.hpp | 332 +++++++++++++-- .../include/device_gemm_xdlops_km_nk_nm.hpp | 263 ++++++++++++ .../include/device_gemm_xdlops_mk_kn_mn.hpp | 334 +++++++++++++-- .../include/device_gemm_xdlops_mk_kn_nm.hpp | 291 +++++++++++++ .../include/device_gemm_xdlops_mk_nk_mn.hpp | 383 ++++++++++++++--- .../include/device_gemm_xdlops_mk_nk_nm.hpp | 347 ++++++++++++++++ .../include/driver_gemm_xdlops_v2r3.hpp | 20 +- .../src/conv_bwd_driver_offline.cpp | 59 ++- .../src/conv_fwd_driver_offline.cpp | 3 +- .../src/conv_wrw_driver_offline.cpp | 3 +- .../src/gemm_driver_offline.cpp | 192 +++++---- host/host_tensor/include/device.hpp | 4 + host/host_tensor/include/gemm_common.hpp | 4 + host/host_tensor/include/host_gemm.hpp | 72 ++++ script/docker-rocm4.3.1.sh | 14 + script/run.sh | 151 +++++-- 28 files changed, 3642 insertions(+), 529 deletions(-) create mode 100644 host/driver_offline/include/debug.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp create mode 100644 host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp create mode 100755 
script/docker-rocm4.3.1.sh diff --git a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp index 9c60e8c3ac..fa78d76965 100644 --- a/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp +++ b/composable_kernel/include/problem_transform/transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp @@ -21,8 +21,8 @@ template __host__ __device__ constexpr auto transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( @@ -33,8 +33,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const ConvDilations& conv_dilations, const InLeftPads& in_left_pads, const InRightPads& in_right_pads, - Number, - Number, + IYTilda i_ytilda, + IXTilda i_xtilda, Number) { constexpr auto I0 = Number<0>{}; @@ -42,9 +42,7 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( constexpr auto I2 = Number<2>{}; constexpr auto I3 = Number<3>{}; - constexpr auto GemmK1 = Number{}; - constexpr auto IYTilda = Number{}; - constexpr auto IXTilda = Number{}; + constexpr auto GemmK1 = Number{}; const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); @@ -98,8 +96,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const auto WTildaSlice = IWTildaSliceEnd - IWTildaSliceBegin; // GemmK is different for each GEMM - const auto YDotSlice = math::integer_divide_ceil(Y - IYTilda, YTilda); - const auto XDotSlice = math::integer_divide_ceil(X - IXTilda, XTilda); + const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilda, YTilda); + const auto XDotSlice = math::integer_divide_ceil(X - i_xtilda, XTilda); const auto K1 = GemmK1; const auto K0 = K / K1; @@ -183,8 +181,8 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_slice_transform(YDot, I0, YDotSlice), make_slice_transform(XDot, I0, XDotSlice), - make_freeze_transform(IYTilda), - make_freeze_transform(IXTilda), + make_freeze_transform(i_ytilda), + make_freeze_transform(i_xtilda), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, Sequence<1>{}, @@ -241,9 +239,9 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( const auto in_n_htildaslice_wtildaslice_c_grid_desc = transform_tensor_descriptor( in_n_ytilda_htilda_xtilda_wtilda_c_grid_desc, make_tuple(make_pass_through_transform(N), - make_freeze_transform(IYTilda), + make_freeze_transform(i_ytilda), make_slice_transform(HTilda, IHTildaSliceBegin, HTildaSlice), - make_freeze_transform(IXTilda), + make_freeze_transform(i_xtilda), make_slice_transform(WTilda, IWTildaSliceBegin, WTildaSlice), make_pass_through_transform(C)), make_tuple(Sequence<0>{}, @@ -271,5 +269,84 @@ transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( in_gemmm_gemmn_grid_desc); } +// A: out +// B: wei +// C: in +// Number of GEMMs = 1 +// GemmM = N * Ho * Wo +// GemmN = C +// GemmK = K +template +__host__ __device__ constexpr auto +transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1( + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const TensorDescriptor& /* wei_k_y_x_c_grid_desc */, + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const ConvStrides& conv_strides, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; 
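// A quick worked instance of the shape mapping stated above (hypothetical
// sizes, for illustration only), assuming GemmK1 = 8:
//   N = 2, Ho = Wo = 28, C = 64, K = 128
//   GemmM = N * Ho * Wo = 2 * 28 * 28 = 1568
//   GemmN = C           = 64
//   GemmK = K           = 128, unmerged into (K0, K1) = (16, 8)
// A is then read as [K0, GemmM, K1], B as [K0, GemmN, K1], and C is written
// as [GemmM, GemmN], matching the descriptors built below.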
+ constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto K1 = GemmK1; + const auto K0 = K / K1; + + // A: output tensor + const auto out_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), + make_unmerge_transform(make_tuple(K0, K1))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + + // B: weight tensor + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C)), + make_tuple(make_unmerge_transform(make_tuple(K0, K1)), make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: input tensor + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(I1, Ho), make_tuple(I1, ConvStrideH)), + make_embed_transform(make_tuple(I1, Wo), make_tuple(I1, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_freeze_transform(I0), + make_freeze_transform(I0), + make_merge_transform(make_tuple(N, Ho, Wo)), + make_pass_through_transform(C)), + make_tuple(Sequence<1>{}, Sequence<3>{}, Sequence<0, 2, 4>{}, Sequence<5>{}), + make_tuple(Sequence<>{}, Sequence<>{}, Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc); +} + } // namespace ck #endif diff --git a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp index 32acceb608..9a73799173 100644 --- a/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp +++ b/composable_kernel/include/tensor_description/multi_index_transform_helper.hpp @@ -31,7 +31,7 @@ __host__ __device__ constexpr auto make_left_pad_transform( return LeftPad{low_length, left_pad}; } -template +template __host__ __device__ constexpr auto make_right_pad_transform( const LowLength& low_length, const RightPadLength& right_pad, diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index c6f491dc47..e3b0054bec 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -29,7 +29,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const AK0MK1GridDesc a_k0_m_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0N0M1N1M2M3M4N2GridDesc c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc 
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -132,7 +132,9 @@ template + bool CAccessOrderMRepeatNRepeat, + bool ABlockLdsExtraM, + bool BBlockLdsExtraN> struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { static constexpr auto I0 = Number<0>{}; @@ -152,14 +154,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -171,29 +193,45 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); } + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} __host__ __device__ static constexpr bool CheckValidity(const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CMNGridDesc& c_m_n_grid_desc) + const CMNGridDesc& c_m_n_grid_desc, + index_t M01, + index_t N01) { - // TODO: turn on this static_assert(is_known_at_compile_time>::value, "wrong! 
K1 need to be known at compile-time"); - const auto M = a_k0_m_k1_grid_desc.GetLength(I1); - const auto N = b_k0_n_k1_grid_desc.GetLength(I1); - const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); - static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && (NPerBlock % (NRepeat * NPerXDL)) == 0, "Invalid tuning param!"); + const auto M = a_k0_m_k1_grid_desc.GetLength(I1); + const auto N = b_k0_n_k1_grid_desc.GetLength(I1); + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_k0_n_k1_grid_desc.GetLength(I0) && K1 == a_k0_m_k1_grid_desc.GetLength(I2) && + K1 == b_k0_n_k1_grid_desc.GetLength(I2))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) - return (M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && - K0 == b_k0_n_k1_grid_desc.GetLength(I0) && - K1 == a_k0_m_k1_grid_desc.GetLength(I2) && - K1 == b_k0_n_k1_grid_desc.GetLength(I2)) && - (M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0); + return true; } __host__ __device__ static constexpr index_t @@ -212,11 +250,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 { constexpr auto max_lds_align = K1; - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1, 3>{})); + + const auto c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M00, N00, M01, N01))), + make_tuple(Sequence<0, 1, 2, 3>{}), + make_tuple(Sequence<0>{})); + const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(M0, N0))), - make_tuple(Sequence<0, 1>{}), - make_tuple(Sequence<0>{})); -#elif 1 - const auto c_blockid_to_m0_n0_block_cluster_adaptor = - make_single_stage_tensor_adaptor(make_tuple(make_merge_transform(make_tuple(N0, M0))), - make_tuple(Sequence<1, 0>{}), - make_tuple(Sequence<0>{})); -#endif + chain_tensor_adaptors(m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_m00_m01_n00_n01_block_cluster_adaptor); return c_blockid_to_m0_n0_block_cluster_adaptor; 
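// A hedged illustration of the resulting block_id -> (m0, n0) mapping, with
// hypothetical tile counts M0 = N0 = 4 and grouping factors M01 = N01 = 2
// (so M00 = N00 = 2): the merge order (M00, N00, M01, N01) varies N01
// fastest, and m0 = m00 * M01 + m01, n0 = n00 * N01 + n01, giving
//   block 0 -> (0, 0)   block 1 -> (0, 1)   block 2 -> (1, 0)   block 3 -> (1, 1)
//   block 4 -> (0, 2)   block 5 -> (0, 3)   block 6 -> (1, 2)   block 7 -> (1, 3)
// Consecutive block ids therefore sweep an M01 x N01 neighborhood of C tiles
// before moving on, which is the locality knob exposed by M01/N01.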
} using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1)); __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -296,14 +367,34 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto max_lds_align = K1; // A matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // B matrix in LDS memory, dst of blockwise copy - // be careful of LDS alignment - constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); // A matrix blockwise copy auto a_blockwise_copy = diff --git a/composable_kernel/include/utility/config.hpp b/composable_kernel/include/utility/config.hpp index c229162d9b..5ee4bb9c64 100644 --- a/composable_kernel/include/utility/config.hpp +++ b/composable_kernel/include/utility/config.hpp @@ -90,8 +90,8 @@ #endif // pass tensor descriptor by value or void* -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 0 -#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 1 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE 1 +#define CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER 0 // merge transformation use magic number division #define CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION 0 diff --git a/host/driver_offline/include/debug.hpp b/host/driver_offline/include/debug.hpp new file mode 100644 index 0000000000..72fd0763ba --- /dev/null +++ b/host/driver_offline/include/debug.hpp @@ -0,0 +1,13 @@ +#ifndef DEBUG_HPP +#define DEBUG_HPP + +namespace debug { +namespace debug_driver_gemm_xdlops_v2r3 { + +// these vars are on host, they control block_id to C matrix tile idx (m0, n0) mapping +static ck::index_t M01 = 1; +static ck::index_t N01 = 1; + +} // namespace debug_driver_gemm_xdlops_v2r3 +} // namespace debug +#endif diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index 8f49473563..b5ff1db296 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -48,8 +48,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); const auto out_n_ho_wo_k_desc = 
make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); -#if 1 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 +#if 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -76,7 +76,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 +#elif 0 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -105,7 +105,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; #elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -159,34 +159,6 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 0 - // [M, N, K0, K1] = [256, 128, 4, 4] - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerXDL = 32; - constexpr index_t GemmNPerXDL = 32; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 4; - constexpr index_t NRepeat = 2; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; #endif @@ -294,13 +266,17 @@ void device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk( decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN >(static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), wei_gemmk0_gemmm_gemmk1_grid_desc, out_gemmk0_gemmn_gemmk1_grid_desc, in_gemmm_gemmn_grid_desc, + debug_driver_gemm_xdlops_v2r3::M01, + debug_driver_gemm_xdlops_v2r3::N01, wei_gemmk0_gemmm_gemmk1_grid_step_hacks, out_gemmk0_gemmn_gemmk1_grid_step_hacks, 
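// Two kinds of newly threaded-through arguments are host-side tuning knobs:
// M01/N01 (see debug.hpp) select the block-id swizzle, while the
// ABlockLdsExtraM / BBlockLdsExtraN flags above select the padded LDS
// descriptor. As a sketch with illustrative numbers (MPerBlock = 128,
// K1 = 4), the padded layout computes
//   offset(k0, m, k1) = k0 * (MPerBlock + 1) * K1 + m * K1 + k1
// so each K0 slice spans 516 rather than 512 elements; the extra K1-wide
// slot skews successive slices across LDS banks, a common bank-conflict
// mitigation bought with a slightly larger LDS footprint.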
in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp index 2cbae2daf3..28d6226f1b 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp @@ -49,7 +49,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -77,7 +77,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -104,8 +104,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -133,7 +133,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -159,25 +159,93 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + 
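// Sanity arithmetic for this variant (a sketch; wave size 64 assumed):
//   MWaves = GemmMPerBlock / (MRepeat * GemmMPerWave) = 128 / (2 * 32) = 2
//   NWaves = GemmNPerBlock / (NRepeat * GemmNPerWave) =  64 / (2 * 32) = 1
//   BlockSize = MWaves * NWaves * 64 = 128
// which is why this configuration drops to 128 threads, while the 128x128
// variant above keeps BlockSize = 256 (MWaves = NWaves = 2).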
constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif - const auto descs = - transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk(out_n_ho_wo_k_desc, - wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - I0, - I0, - Number{}); - - const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto in_gemmm_gemmn_grid_desc = descs[I2]; - // HACK: hacks that control index calculation when iterating over A, B, C matrix constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 @@ -185,7 +253,8 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 2+: gemmk1 make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: gemmk0 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0>{}, // 1-: gemmm - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: gemmk1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 2-: + // gemmk1 constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: gemmk0 @@ -215,7 +284,7 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk Sequence<0, 0, 0, 
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - //clang-format on + // clang-format on constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0>{}; @@ -225,64 +294,110 @@ void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk for(index_t i = 0; i < 5; ++i) { - float ave_time = driver_gemm_xdlops_v2r3< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(out_gemmk0_gemmm_gemmk1_grid_desc), - decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), - decltype(in_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - GemmK1, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<2, 0, 1>, - Sequence<0, 2, 1>, - 1, - GemmBBlockTransferSrcScalarPerVector_GemmN, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH); + const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW); + + const auto YTilda = ConvStrideH / GcdStrideDilationH; + const auto XTilda = ConvStrideW / GcdStrideDilationW; + + float ave_time = 0; + + for(index_t i_ytilda = 0; i_ytilda < YTilda; ++i_ytilda) + { + for(index_t i_xtilda = 0; i_xtilda < XTilda; ++i_xtilda) + { + const auto descs = + transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk( + out_n_ho_wo_k_desc, + wei_k_y_x_c_desc, + in_n_hi_wi_c_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + i_ytilda, + i_xtilda, + Number{}); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto in_gemmm_gemmn_grid_desc = descs[I2]; + + const auto GemmK0 = out_gemmk0_gemmm_gemmk1_grid_desc.GetLength(I0); + + if(GemmK0 != 0) + { + ave_time += driver_gemm_xdlops_v2r3< + BlockSize, + TInWei, + TAcc, + TOut, + InMemoryDataOperationEnum_t::Set, + decltype(out_gemmk0_gemmm_gemmk1_grid_desc), + decltype(wei_gemmk0_gemmn_gemmk1_grid_desc), + decltype(in_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerWave, + GemmNPerWave, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + 
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<2, 0, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy #if 0 - Sequence<0, 2, 4, 5, 6, 1, 3, 7>, + Sequence<0, 2, 4, 5, 6, 1, 3, 7>, #else - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, + Sequence<0, 1, 2, 3, 4, 5, 6, 7>, #endif - 7, - GemmCThreadTransferDstScalarPerVector, - decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), - decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - true // CAccessOrderMRepeatNRepeat - >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - wei_gemmk0_gemmn_gemmk1_grid_desc, - in_gemmm_gemmn_grid_desc, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_step_hacks, - in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + true, // CAccessOrderMRepeatNRepeat + false, // ABlockLdsExtraM + false // BBlockLdsExtraN + >(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + wei_gemmk0_gemmn_gemmk1_grid_desc, + in_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_step_hacks, + in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + } + } + } { const auto N = out_n_ho_wo_k_lengths[I0]; diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp new file mode 100644 index 0000000000..d6955ec000 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp @@ -0,0 +1,389 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations&, + const InLeftPads&, + const InRightPads&, + Tensor& in_n_hi_wi_c, + const Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + 
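// Why a dedicated 1x1 path (a sketch of the reasoning): the general
// backward-data driver decomposes the problem into YTilda * XTilda sub-GEMMs,
// where, e.g., stride 2 with dilation 1 gives
//   YTilda = ConvStrideH / gcd(ConvStrideH, ConvDilationH) = 2 / 1 = 2
// and likewise XTilda = 2, i.e. four kernel launches. With a 1x1 filter
// (Y = X = 1) only the (i_ytilda, i_xtilda) = (0, 0) slice is non-empty,
// so this driver launches the single GEMM directly and folds the stride
// into the input-tensor descriptor instead.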
using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t 
NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0>{}, // 1+: gemmm + Sequence<0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: gemmk0 + Sequence<0, 0, 0>{}, // 1-: gemmm + Sequence<0, 0, 0>{})); // 2-: gemmk1 + + constexpr auto wei_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: gemmk0 + Sequence<0, 0, 0>{}, // 1+: gemmn + Sequence<0, 0, 0>{}), // 2+: gemmk1 + make_tuple(Sequence<0, 0, 0>{}, // 0-: Gemmk0 + Sequence<0, 0, 0>{}, // 1-: Gemmn + Sequence<0, 0, 0>{})); // 2-: Gemmk1 + + // clang-format off + constexpr auto in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple( + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple( + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 
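+        // NOTE: each Sequence in these tuples is a per-dimension hint consumed by
+        // the coordinate-update code: 0 means "no special handling", and the
+        // nonzero entries (1 in the "+" half, 2 in the "-" half) appear to tag
+        // the hidden-transform slot whose forward/backward index update can take
+        // the cheap path. This is a reviewer's reading of the "HACK" comment
+        // above, not behavior verified in this patch.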
0, 0, 0, 0>{}, // 6-: M4
+            Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
+    // clang-format on
+
+    constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+
+    constexpr auto wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
+
+    for(index_t i = 0; i < 5; ++i)
+    {
+        const auto descs = transform_backward_data_convolution_into_gemm_v4r1r2_nhwc_kyxc_nhwk_1x1(
+            out_n_ho_wo_k_desc,
+            wei_k_y_x_c_desc,
+            in_n_hi_wi_c_desc,
+            conv_strides,
+            Number<GemmK1>{});
+
+        const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0];
+        const auto wei_gemmk0_gemmn_gemmk1_grid_desc = descs[I1];
+        const auto in_gemmm_gemmn_grid_desc          = descs[I2];
+
+        float ave_time = driver_gemm_xdlops_v2r3<
+            BlockSize,
+            TInWei,
+            TAcc,
+            TOut,
+            InMemoryDataOperationEnum_t::Set,
+            decltype(out_gemmk0_gemmm_gemmk1_grid_desc),
+            decltype(wei_gemmk0_gemmn_gemmk1_grid_desc),
+            decltype(in_gemmm_gemmn_grid_desc),
+            GemmMPerBlock,
+            GemmNPerBlock,
+            GemmKPerBlock,
+            GemmMPerWave,
+            GemmNPerWave,
+            GemmK1,
+            MRepeat,
+            NRepeat,
+            GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1,
+            GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1,
+            Sequence<1, 0, 2>,
+            Sequence<1, 0, 2>,
+            2,
+            GemmABlockTransferSrcScalarPerVector_GemmK1,
+            GemmABlockTransferDstScalarPerVector_GemmK1,
+            false, // don't move back src coordinate after threadwise copy
+            GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1,
+            GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1,
+            Sequence<2, 0, 1>,
+            Sequence<0, 2, 1>,
+            1,
+            GemmBBlockTransferSrcScalarPerVector_GemmN,
+            GemmBBlockTransferDstScalarPerVector_GemmK1,
+            false, // don't move back src coordinate after threadwise copy
+#if 0
+            Sequence<0, 2, 4, 5, 6, 1, 3, 7>,
+#else
+            Sequence<0, 1, 2, 3, 4, 5, 6, 7>,
+#endif
+            7,
+            GemmCThreadTransferDstScalarPerVector,
+            decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks),
+            decltype(wei_gemmk0_gemmn_gemmk1_grid_step_hacks),
+            decltype(in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
+            decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks),
+            decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks),
+            true,  // CAccessOrderMRepeatNRepeat
+            false, // ABlockLdsExtraM
+            false  // BBlockLdsExtraN
+            >(static_cast<TOut*>(out_n_ho_wo_k_device_buf.GetDeviceBuffer()),
+              static_cast<TInWei*>(wei_k_y_x_c_device_buf.GetDeviceBuffer()),
+              static_cast<TInWei*>(in_n_hi_wi_c_device_buf.GetDeviceBuffer()),
+              out_gemmk0_gemmm_gemmk1_grid_desc,
+              wei_gemmk0_gemmn_gemmk1_grid_desc,
+              in_gemmm_gemmn_grid_desc,
+              debug::debug_driver_gemm_xdlops_v2r3::M01,
+              debug::debug_driver_gemm_xdlops_v2r3::N01,
+              out_gemmk0_gemmm_gemmk1_grid_step_hacks,
+              wei_gemmk0_gemmn_gemmk1_grid_step_hacks,
+              in_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
+              out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks,
+              wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks,
+              nrepeat);
+
+        {
+            const auto N = out_n_ho_wo_k_lengths[I0];
+            const auto K = out_n_ho_wo_k_lengths[I3];
+            const auto C = wei_k_y_x_c_lengths[I3];
+
+            const auto Ho = out_n_ho_wo_k_lengths[I1];
+            const auto Wo = out_n_ho_wo_k_lengths[I2];
+
+            const auto Y = wei_k_y_x_c_lengths[I1];
+            const auto X = wei_k_y_x_c_lengths[I2];
+
+            float perf = static_cast<float>((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) /
+                         (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+            std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s"
+                      << std::endl;
+        }
+    }
+
+    // copy result back to host
+
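+    // NOTE on the perf print above: ave_time is in ms, so
+    //     flops / 1e9 / ave_time[ms] == flops / (1e12 * seconds),
+    // i.e. the printed number really is TFlop/s. Hypothetical worked example:
+    //     N = 128, K = 256, Ho = Wo = 14, C = 256, Y = X = 3
+    //     => 2 * N * K * Ho * Wo * C * Y * X ~= 2.96e10 flops;
+    //     at ave_time = 10 ms this prints 2.96e10 / 1e9 / 10 ~= 2.96 TFlop/s.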
in_n_hi_wi_c_device_buf.FromDevice(in_n_hi_wi_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index e97bc9c1c7..b8ecfb4be9 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -203,18 +203,23 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false>(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), - out_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - wei_gemmm_gemmn_grid_desc, - out_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, - out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN + >(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); float perf = static_cast(calculate_convolution_flops( in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 52432664de..01e5c57ab4 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -49,7 +49,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); #if 0 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [256, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -77,7 +77,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 128, for fp32 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -105,7 +105,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t 
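+    // NOTE: the "C = ..." tags added to these branch comments appear to record
+    // the input-channel count each parameter set was tuned for; since the gemm-K
+    // reads run along C in NHWC, a usable C must at least be a multiple of
+    // GemmK1 (8 in the fp16 branches, 4 in fp32). Reviewer's inference, not
+    // verified against the tuning runs.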
GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [256, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 256, 4, 8], C = 256, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -133,7 +133,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 256; @@ -160,8 +160,8 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -189,7 +189,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 constexpr index_t BlockSize = 256; constexpr index_t GemmMPerBlock = 128; @@ -215,6 +215,62 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8>; + using 
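+    // NOTE: sketch of the invariant the cluster Sequences must satisfy
+    // (hypothetical helper, not part of this patch):
+    //
+    //     template <typename Cluster>
+    //     constexpr index_t cluster_size()
+    //     {
+    //         return Cluster::At(Number<0>{}) * Cluster::At(Number<1>{}) *
+    //                Cluster::At(Number<2>{});
+    //     }
+    //     // for the B cluster completed just below: 4 * 64 * 1 == 256 == BlockSize
+    //
+    // one thread per cluster cell, so the cluster size must equal BlockSize.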
GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; #endif @@ -316,13 +372,17 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( decltype(out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), decltype(wei_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), in_gemmk0_gemmm_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc, out_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, in_gemmk0_gemmm_gemmk1_grid_step_hacks, wei_gemmk0_gemmn_gemmk1_grid_step_hacks, out_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp index d9169649e6..c44aa7d9a2 100644 --- a/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template -void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, - const BDesc& b_k_n_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor& a_k_m, +template +void device_gemm_xdlops_km_kn_mn(const Tensor& a_k_m, const Tensor& b_k_n, Tensor& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -60,9 +49,121 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4], C = 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 
128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 256; @@ -88,46 +189,185 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr 
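+    // NOTE: in this km driver A is stored K x M, so the contiguous dimension of
+    // a_k_m is M; A's global reads therefore vectorize along M (the knob is named
+    // ...SrcScalarPerVector_M), while the LDS-side write vectorizes along K1.
+    // The source vector width is bounded by the per-thread slice: with
+    // Slice<1, 2, 8> a thread owns at most 2 contiguous M elements, matching
+    // ABlockTransferSrcScalarPerVector_M = 2 in this branch.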
index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 
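+    // NOTE: the value 1 completed on the next line is forced by the B slice
+    // Sequence<1, 1, 8>: each thread owns a single N element per copy, and the
+    // source vector width can never exceed the slice length of the vectorized
+    // dimension.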
1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_k_m_grid_desc.GetLength(I0); - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_k_n.mDesc.GetLengths()[1]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_k_m_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_k_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 
0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -147,9 +387,9 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; for(index_t i = 0; i < 5; ++i) { @@ -194,13 +434,17 @@ void device_gemm_xdlops_km_kn_mn(const ADesc& a_k_m_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), static_cast(b_k_n_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp new file mode 100644 index 0000000000..abaaf32113 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_kn_nm.hpp @@ -0,0 +1,263 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_kn_nm(const Tensor& a_k_m, + const Tensor& b_k_n, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + 
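+    // NOTE: this km_kn_nm driver differs from km_kn_mn only in where C lives: the
+    // result is written into c_n_m (N x M storage). No transpose pass is needed;
+    // the c_m_n_grid_desc built further down just swaps the two strides of c_n_m,
+    // so logical (m, n) lands at physical (n, m). Equivalent host-side view
+    // (illustrative sketch, not part of this patch):
+    //
+    //     // element (m, n) of the logical M x N result:
+    //     CType c_mn = c_n_m.mData[n * c_n_m.mDesc.GetStrides()[0] +
+    //                              m * c_n_m.mDesc.GetStrides()[1]];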
constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), 
make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0]));
+
+    // HACK: hacks that control index calculation when iterating over A, B, C matrix
+    constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                                                                     Sequence<0>{},   // 1+: M
+                                                                     Sequence<0>{}),  // 2+: K1
+                                                          make_tuple(Sequence<0>{},   // 0-: K0
+                                                                     Sequence<0>{},   // 1-: M
+                                                                     Sequence<0>{})); // 2-: K1
+
+    constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{},   // 0+: K0
+                                                                     Sequence<0>{},   // 1+: N
+                                                                     Sequence<0>{}),  // 2+: K1
+                                                          make_tuple(Sequence<0>{},   // 0-: K0
+                                                                     Sequence<0>{},   // 1-: N
+                                                                     Sequence<0>{})); // 2-: K1
+
+    constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
+        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
+                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
+                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
+
+    constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
+
+    constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{};
+
+    for(index_t i = 0; i < 5; ++i)
+    {
+        float ave_time =
+            driver_gemm_xdlops_v2r3<BlockSize,
+                                    ABType,
+                                    AccType,
+                                    CType,
+                                    InMemoryDataOperationEnum_t::Set,
+                                    decltype(a_k0_m_k1_grid_desc),
+                                    decltype(b_k0_n_k1_grid_desc),
+                                    decltype(c_m_n_grid_desc),
+                                    MPerBlock,
+                                    NPerBlock,
+                                    KPerBlock,
+                                    MPerXDL,
+                                    NPerXDL,
+                                    K1,
+                                    MRepeat,
+                                    NRepeat,
+                                    ABlockTransferThreadSliceLengths_K0_M_K1,
+                                    ABlockTransferThreadClusterLengths_K0_M_K1,
+                                    Sequence<0, 2, 1>,
+                                    Sequence<0, 2, 1>,
+                                    1,
+                                    ABlockTransferSrcScalarPerVector_M,
+                                    ABlockTransferDstScalarPerVector_K1,
+                                    false, // don't move back src coordinate after threadwise copy
+                                    BBlockTransferThreadSliceLengths_K0_N_K1,
+                                    BBlockTransferThreadClusterLengths_K0_N_K1,
+                                    Sequence<0, 2, 1>,
+                                    Sequence<0, 2, 1>,
+                                    1,
+                                    BBlockTransferSrcScalarPerVector_N,
+                                    BBlockTransferDstScalarPerVector_K1,
+                                    false, // don't move back src coordinate after threadwise copy
+                                    Sequence<2, 3, 0, 1, 7, 5, 4, 6>,
+                                    6,
+                                    CThreadTransferDstScalarPerVector,
+                                    decltype(a_k0_m_k1_grid_step_hacks),
+                                    decltype(b_k0_n_k1_grid_step_hacks),
+                                    decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),
+                                    decltype(a_k0_m_k1_grid_move_slice_window_step_hacks),
+                                    decltype(b_k0_n_k1_grid_move_slice_window_step_hacks),
+                                    false // CAccessOrderMRepeatNRepeat
+                                    >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()),
+                                      static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()),
+                                      static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()),
+                                      a_k0_m_k1_grid_desc,
+                                      b_k0_n_k1_grid_desc,
+                                      c_m_n_grid_desc,
+                                      a_k0_m_k1_grid_step_hacks,
+                                      b_k0_n_k1_grid_step_hacks,
+                                      c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks,
+                                      a_k0_m_k1_grid_move_slice_window_step_hacks,
+                                      b_k0_n_k1_grid_move_slice_window_step_hacks,
+                                      nrepeat);
+
+        float perf = static_cast<float>((std::size_t(2) * M * N * K)) /
+                     (std::size_t(1000) * 1000 * 1000) / ave_time;
+
+        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
+    }
+
+    // copy result back to host
+    c_n_m_device_buf.FromDevice(c_n_m.mData.data());
+}
diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp
index 90e258d581..0a97d361d4 100644
--- 
a/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template -void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, - const BDesc& b_n_k_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor& a_k_m, +template +void device_gemm_xdlops_km_nk_mn(const Tensor& a_k_m, const Tensor& b_n_k, Tensor& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -60,9 +49,121 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = 
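+    // NOTE: b_n_k stores K contiguously, so this nk driver vectorizes the B
+    // global reads along K1 (the knobs are ...SrcScalarPerVector_K1), unlike the
+    // kn drivers which vectorize B along N. The b_k0_n_k1 descriptor built below
+    // keeps stride 1 on K1 and folds K0 on top of it; for a row-major N x K
+    // tensor (strides [K, 1]) that yields strides [K1, K, 1] for [K0, N, K1].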
Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 256; @@ -88,46 +189,185 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t 
ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 1; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_k_m_grid_desc.GetLength(I0); - const auto M = a_k_m_grid_desc.GetLength(I1); - const auto N = b_n_k_grid_desc.GetLength(I0); + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto 
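+    // NOTE: sanity math for the descriptors below, with hypothetical sizes
+    // K = 1024 and K1 = 8, so K0 = K / K1 = 128. For a row-major K x M a_k_m
+    // (strides [M, 1]) the a_k0_m_k1 descriptor gets lengths [K0, M, K1] and
+    // strides [K1 * M, 1, M]: stepping K1 walks down K one row at a time and
+    // stepping K0 jumps K1 rows. The division truncates, so K is assumed to be
+    // a multiple of K1.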
N = b_n_k.mDesc.GetLengths()[0]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_k_m_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(M)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_n_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -147,9 +387,9 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; for(index_t i = 0; i < 5; ++i) { @@ -194,13 +434,17 @@ void device_gemm_xdlops_km_nk_mn(const ADesc& a_k_m_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast(a_k_m_device_buf.GetDeviceBuffer()), static_cast(b_n_k_device_buf.GetDeviceBuffer()), static_cast(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + 
debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp new file mode 100644 index 0000000000..d51caa3847 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_km_nk_nm.hpp @@ -0,0 +1,263 @@ +#pragma once +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template +void device_gemm_xdlops_km_nk_nm(const Tensor& a_k_m, + const Tensor& b_n_k, + Tensor& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_k_m_device_buf(sizeof(ABType) * a_k_m.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_k_m_device_buf.ToDevice(a_k_m.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 2; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using 
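+    // NOTE: wave accounting for this branch (reviewer's check):
+    //     MWaves = MPerBlock / (MRepeat * MPerXDL) = 256 / (4 * 32) = 2
+    //     NWaves = NPerBlock / (NRepeat * NPerXDL) = 128 / (2 * 32) = 2
+    //     MWaves * NWaves * 64 (wavefront)         = 256 == BlockSize
+    // so the block tile is covered with no idle waves.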
ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_M = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_k_m.mDesc.GetLengths()[0]; + const auto M = a_k_m.mDesc.GetLengths()[1]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_k_m.mDesc.GetStrides()[0], + a_k_m.mDesc.GetStrides()[1], + a_k_m.mDesc.GetStrides()[0])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, 
// 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3<BlockSize, + ABType, + AccType, + CType, + InMemoryDataOperationEnum_t::Set, + decltype(a_k0_m_k1_grid_desc), + decltype(b_k0_n_k1_grid_desc), + decltype(c_m_n_grid_desc), + MPerBlock, + NPerBlock, + KPerBlock, + MPerXDL, + NPerXDL, + K1, + MRepeat, + NRepeat, + ABlockTransferThreadSliceLengths_K0_M_K1, + ABlockTransferThreadClusterLengths_K0_M_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + ABlockTransferSrcScalarPerVector_M, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast<ABType*>(a_k_m_device_buf.GetDeviceBuffer()), + static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()), + static_cast<CType*>(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast<float>((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp index ab235d97e7..30ede2517b 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template <typename ABType, - typename AccType, - typename CType, - typename ADesc, - typename BDesc, - typename CDesc> -void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, - const BDesc& b_k_n_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor<ABType>& a_m_k, +template <typename ABType, typename AccType, typename CType> +void device_gemm_xdlops_mk_kn_mn(const Tensor<ABType>& a_m_k, const Tensor<ABType>& b_k_n, Tensor<CType>& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -33,8 +22,148 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, b_k_n_device_buf.ToDevice(b_k_n.mData.data()); c_m_n_device_buf.ToDevice(c_m_n.mData.data()); -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; +
constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t 
BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 256; @@ -88,46 +217,157 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr 
index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 1; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_m_k_grid_desc.GetLength(I1); - const auto M = a_m_k_grid_desc.GetLength(I0); - const auto N = b_k_n_grid_desc.GetLength(I1); + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_k_n.mDesc.GetLengths()[1]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_m_k_grid_desc, - make_tuple(make_pass_through_transform(M), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_k_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)), - make_pass_through_transform(N)), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 
1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -147,9 +387,9 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; for(index_t i = 0; i < 5; ++i) { @@ -194,13 +434,17 @@ void device_gemm_xdlops_mk_kn_mn(const ADesc& a_m_k_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()), static_cast<ABType*>(b_k_n_device_buf.GetDeviceBuffer()), static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp new file mode 100644 index 0000000000..58ac3880d6 --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_kn_nm.hpp @@ -0,0 +1,291 @@ +#pragma once +#include <unistd.h> +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template <typename ABType, typename AccType, typename CType> +void device_gemm_xdlops_mk_kn_nm(const Tensor<ABType>& a_m_k, + const Tensor<ABType>& b_k_n, + Tensor<CType>& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_k_n_device_buf(sizeof(ABType) * b_k_n.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_k_n_device_buf.ToDevice(b_k_n.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4;
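+ + // NOTE (editorial sketch, not part of the original patch): the tuning parameters in each #if/#elif branch of these device_gemm_xdlops_* headers have to agree with one another: elementwise, the thread-slice lengths times the thread-cluster lengths must reproduce the [K0, M (or N), K1] block tile, and the cluster must hold exactly BlockSize work-items. + // For this [256, 128, 4, 4] fp32 branch, with the slice/cluster aliases declared just below, the check would read: + // static_assert(1 * 4 == KPerBlock && 4 * 64 == MPerBlock && 4 * 1 == K1, ""); + // static_assert(4 * 64 * 1 == BlockSize, ""); + // The same arithmetic applies to every other branch in this file.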
+ constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 2; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // 
[M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_N = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_k_n.mDesc.GetLengths()[1]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_k_n.mDesc.GetStrides()[0], + b_k_n.mDesc.GetStrides()[1], + b_k_n.mDesc.GetStrides()[0])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + 
ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + BBlockTransferSrcScalarPerVector_N, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_k_n_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp index c68442d127..e99d570413 100644 --- a/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_mn.hpp @@ -4,16 +4,8 @@ #include "host_tensor.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template -void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, - const BDesc& b_n_k_grid_desc, - const CDesc& c_m_n_grid_desc, - const Tensor& a_m_k, +template +void device_gemm_xdlops_mk_nk_mn(const Tensor& a_m_k, const Tensor& b_n_k, Tensor& c_m_n, ck::index_t nrepeat) @@ -22,9 +14,6 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, std::cout << __func__ << std::endl; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); DeviceMem c_m_n_device_buf(sizeof(CType) * c_m_n.mDesc.GetElementSpace()); @@ -34,6 +23,34 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, c_m_n_device_buf.ToDevice(c_m_n.mData.data()); #if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t 
BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 constexpr index_t BlockSize = 256; @@ -60,9 +77,93 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C = 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 4], C = 32, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [256, 128, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t 
MPerBlock = 256; @@ -90,7 +191,7 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t CThreadTransferDstScalarPerVector = 1; #elif 0 - // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 128; @@ -117,8 +218,36 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; constexpr index_t CThreadTransferDstScalarPerVector = 1; -#elif 1 - // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 constexpr index_t BlockSize = 256; constexpr index_t MPerBlock = 128; @@ -144,46 +273,131 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 64, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 64; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr 
index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 1, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + constexpr index_t CThreadTransferDstScalarPerVector = 1; #endif - const auto K = a_m_k_grid_desc.GetLength(I1); - const auto M = a_m_k_grid_desc.GetLength(I0); - const auto N = b_n_k_grid_desc.GetLength(I0); + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_n_k.mDesc.GetLengths()[0]; constexpr auto K1Number = Number{}; const auto K0 = K / K1Number; +#if 1 + // non-padded GEMM const auto a_k0_m_k1_grid_desc = - transform_tensor_descriptor(a_m_k_grid_desc, - make_tuple(make_pass_through_transform(M), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); const auto b_k0_n_k1_grid_desc = - transform_tensor_descriptor(b_n_k_grid_desc, - make_tuple(make_pass_through_transform(N), - make_unmerge_transform(make_tuple(K0, K1Number))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<1>{}, Sequence<0, 2>{})); + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); // HACK: hacks that control index calculation when iterating over A, B, C matrix - constexpr auto a_k0_m_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: M - Sequence<0, 0, 0>{}), // 2+: K1 - make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: M - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 - constexpr auto b_k0_n_k1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0 - Sequence<0, 0, 0>{}, // 1+: N - Sequence<0, 0, 0>{}), // 2+: K1 - 
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0 - Sequence<0, 0, 0>{}, // 1-: N - Sequence<0, 0, 0>{})); // 2-: K1 + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 @@ -203,9 +417,80 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 - constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; - constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; +#else + // padded GEMM + const auto a_k0_m_k1_grid_desc_tmp = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto MRightPad = math::integer_divide_ceil(M, MPerBlock) * MPerBlock - M; + + const auto a_k0_m_k1_grid_desc = + transform_tensor_descriptor(a_k0_m_k1_grid_desc_tmp, + make_tuple(make_pass_through_transform(K0), + make_right_pad_transform(M, MRightPad), + make_pass_through_transform(K1Number)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc_tmp = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_m_n.mDesc.GetStrides()[0], c_m_n.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc_tmp, + make_tuple(make_right_pad_transform(M, MRightPad), make_pass_through_transform(N)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0>{}, // 0+: K0 + Sequence<0, 0, 0, 0>{}, // 1+: M + Sequence<0, 0, 0, 0>{}), // 2+: K1 + make_tuple(Sequence<0, 0, 0, 0>{}, // 0-: K0 + Sequence<0, 0, 0, 0>{}, // 1-: M + Sequence<0, 0, 0, 0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; +#endif for(index_t i = 0; i < 5; ++i) { @@ -250,13 +535,17 @@ void device_gemm_xdlops_mk_nk_mn(const ADesc& a_m_k_grid_desc, decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), - false // CAccessOrderMRepeatNRepeat + false, // CAccessOrderMRepeatNRepeat + true, // ABlockLdsExtraM + true // BBlockLdsExtraN >(static_cast<ABType*>(a_m_k_device_buf.GetDeviceBuffer()), static_cast<ABType*>(b_n_k_device_buf.GetDeviceBuffer()), static_cast<CType*>(c_m_n_device_buf.GetDeviceBuffer()), a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, a_k0_m_k1_grid_step_hacks, b_k0_n_k1_grid_step_hacks, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, diff --git a/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp b/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp new file mode 100644 index 0000000000..a12cf0733a --- /dev/null +++ b/host/driver_offline/include/device_gemm_xdlops_mk_nk_nm.hpp @@ -0,0 +1,347 @@ +#pragma once +#include <unistd.h> +#include "device.hpp" +#include "host_tensor.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" + +template <typename ABType, typename AccType, typename CType> +void device_gemm_xdlops_mk_nk_nm(const Tensor<ABType>& a_m_k, + const Tensor<ABType>& b_n_k, + Tensor<CType>& c_n_m, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + DeviceMem a_m_k_device_buf(sizeof(ABType) * a_m_k.mDesc.GetElementSpace()); + DeviceMem b_n_k_device_buf(sizeof(ABType) * b_n_k.mDesc.GetElementSpace()); + DeviceMem c_n_m_device_buf(sizeof(CType) * c_n_m.mDesc.GetElementSpace()); + + a_m_k_device_buf.ToDevice(a_m_k.mData.data()); + b_n_k_device_buf.ToDevice(b_n_k.mData.data()); + c_n_m_device_buf.ToDevice(c_n_m.mData.data()); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t
NPerXDL = 32; + constexpr index_t K1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 4>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 4; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 4>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 4; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 4; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 256; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 256; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 128, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 4, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 32, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 4, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 32, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t 
BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 128; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 2, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C = 32, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t MPerBlock = 64; + constexpr index_t NPerBlock = 128; + constexpr index_t KPerBlock = 4; + + constexpr index_t MPerXDL = 32; + constexpr index_t NPerXDL = 32; + constexpr index_t K1 = 8; + + constexpr index_t MRepeat = 1; + constexpr index_t NRepeat = 2; + + using ABlockTransferThreadSliceLengths_K0_M_K1 = Sequence<1, 1, 8>; + using ABlockTransferThreadClusterLengths_K0_M_K1 = Sequence<4, 64, 1>; + + constexpr index_t ABlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t ABlockTransferDstScalarPerVector_K1 = 8; + + using BBlockTransferThreadSliceLengths_K0_N_K1 = Sequence<1, 2, 8>; + using BBlockTransferThreadClusterLengths_K0_N_K1 = Sequence<4, 64, 1>; + + constexpr index_t BBlockTransferSrcScalarPerVector_K1 = 8; + constexpr index_t BBlockTransferDstScalarPerVector_K1 = 8; + + constexpr index_t CThreadTransferDstScalarPerVector = 4; +#endif + + const auto K = a_m_k.mDesc.GetLengths()[1]; + const auto M = a_m_k.mDesc.GetLengths()[0]; + const auto N = b_n_k.mDesc.GetLengths()[0]; + + constexpr auto K1Number = Number{}; + const auto K0 = K / K1Number; + + const auto a_k0_m_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, M, K1Number), + make_tuple(K1Number * a_m_k.mDesc.GetStrides()[1], + a_m_k.mDesc.GetStrides()[0], + a_m_k.mDesc.GetStrides()[1])); + + const auto b_k0_n_k1_grid_desc = + make_naive_tensor_descriptor(make_tuple(K0, N, K1Number), + make_tuple(K1Number * b_n_k.mDesc.GetStrides()[1], + b_n_k.mDesc.GetStrides()[0], + b_n_k.mDesc.GetStrides()[1])); + + const auto c_m_n_grid_desc = make_naive_tensor_descriptor( + make_tuple(M, N), make_tuple(c_n_m.mDesc.GetStrides()[1], c_n_m.mDesc.GetStrides()[0])); + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: M + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: M + Sequence<0>{})); // 2-: K1 + + constexpr auto b_k0_n_k1_grid_step_hacks = make_tuple(make_tuple(Sequence<0>{}, // 0+: K0 + Sequence<0>{}, // 1+: N + Sequence<0>{}), // 2+: K1 + make_tuple(Sequence<0>{}, // 0-: K0 + Sequence<0>{}, // 1-: N + Sequence<0>{})); // 2-: K1 + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 
0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops_v2r3, + Sequence<1, 0, 2>, + 2, + ABlockTransferSrcScalarPerVector_K1, + ABlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + Sequence<1, 0, 2>, + Sequence<1, 0, 2>, + 2, + BBlockTransferSrcScalarPerVector_K1, + BBlockTransferDstScalarPerVector_K1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + CThreadTransferDstScalarPerVector, + decltype(a_k0_m_k1_grid_step_hacks), + decltype(b_k0_n_k1_grid_step_hacks), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), + decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), + false // CAccessOrderMRepeatNRepeat + >(static_cast(a_m_k_device_buf.GetDeviceBuffer()), + static_cast(b_n_k_device_buf.GetDeviceBuffer()), + static_cast(c_n_m_device_buf.GetDeviceBuffer()), + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m_n_grid_desc, + a_k0_m_k1_grid_step_hacks, + b_k0_n_k1_grid_step_hacks, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + a_k0_m_k1_grid_move_slice_window_step_hacks, + b_k0_n_k1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast((std::size_t(2) * M * N * K)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + // copy result back to host + c_n_m_device_buf.FromDevice(c_n_m.mData.data()); +} diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index 2bf8adba84..91ea24f947 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -1,5 +1,5 @@ -#ifndef DRIVER_GEMM_XDLOPS_V2R3 -#define DRIVER_GEMM_XDLOPS_V2R3 +#ifndef DRIVER_GEMM_XDLOPS_V2R3_HPP +#define DRIVER_GEMM_XDLOPS_V2R3_HPP #include "common_header.hpp" #include "tensor_descriptor.hpp" @@ -46,13 +46,17 @@ template + bool CAccessOrderMRepeatNRepeat, + bool ABlockLdsAddExtraM, + bool BBlockLdsAddExtraN> __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, const FloatAB* p_b_grid, FloatC* p_c_grid, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, const CMNGridDesc& c_m_n_grid_desc, + ck::index_t M01, + ck::index_t N01, AGridStepHacks, BGridStepHacks, CGridStepHacks, @@ -108,7 +112,9 @@ __host__ float 
driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, CGridStepHacks, AGridMoveSliceWindowStepHacks, BGridMoveSliceWindowStepHacks, - CAccessOrderMRepeatNRepeat>; + CAccessOrderMRepeatNRepeat, + ABlockLdsAddExtraM, + BBlockLdsAddExtraN>; { std::cout << "a_k0_m_k1_grid_desc{" << a_k0_m_k1_grid_desc.GetLength(I0) << ", " @@ -123,7 +129,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; } - if(!GridwiseGemm::CheckValidity(a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc)) + if(!GridwiseGemm::CheckValidity( + a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); @@ -134,7 +141,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); - const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); + const auto c_block_cluster_adaptor = + GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01); using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor); diff --git a/host/driver_offline/src/conv_bwd_driver_offline.cpp b/host/driver_offline/src/conv_bwd_driver_offline.cpp index 4e93ada859..366b5dffbc 100644 --- a/host/driver_offline/src/conv_bwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_bwd_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -14,15 +15,16 @@ #include "device_tensor.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp" #include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk.hpp" +#include "device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1.hpp" #define USE_MODE 1 -#define USE_CONV_BWD_V4R1_XDL_NHWC 1 +#define USE_CONV_BWD_V4R1_XDL_NHWC 0 #define USE_CONV_BWD_V4R1R2_XDL_NHWC 1 enum ConvBackwardDataAlgo { - V4R1XDLNHWC, - V4R1R2XDLNHWC, + V4R1XDLNHWC, // 0 + V4R1R2XDLNHWC, // 1 }; int main(int argc, char* argv[]) @@ -280,20 +282,43 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nhwc(); - device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( - tmp[I0], - tmp[I1], - tmp[I2], - tmp[I3], - tmp[I4], - tmp[I5], - tmp[I6], - in_device, - wei, - out, - nrepeat); + if(Y == 1 && X == 1 && in_left_pad_h == 0 && in_left_pad_w == 0 && in_right_pad_h == 0 && + in_right_pad_w == 0) + { + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk_1x1< + in_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); + } + else + { +#if 1 + device_convolution_backward_data_implicit_gemm_v4r1r2_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in_device, + wei, + out, + nrepeat); +#endif + } } #endif diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 34d7247f3c..48eba2b372 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -24,7 +25,7 @@ #define 
USE_CONV_FWD_V4R4R2_NHWC 0 #define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 #define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp index 13c73abf30..310dbfe1eb 100644 --- a/host/driver_offline/src/conv_wrw_driver_offline.cpp +++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -111,7 +112,7 @@ int main(int argc, char* argv[]) constexpr auto Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + I1; #endif -#if 1 +#if 0 using in_data_t = float; using acc_data_t = float; using out_data_t = float; diff --git a/host/driver_offline/src/gemm_driver_offline.cpp b/host/driver_offline/src/gemm_driver_offline.cpp index 42c69ff6a2..e60b4905ae 100644 --- a/host/driver_offline/src/gemm_driver_offline.cpp +++ b/host/driver_offline/src/gemm_driver_offline.cpp @@ -5,6 +5,7 @@ #include #include #include "config.hpp" +#include "debug.hpp" #include "print.hpp" #include "device.hpp" #include "host_tensor.hpp" @@ -16,11 +17,19 @@ #include "device_gemm_xdlops_mk_nk_mn.hpp" #include "device_gemm_xdlops_km_kn_mn.hpp" #include "device_gemm_xdlops_km_nk_mn.hpp" +#include "device_gemm_xdlops_mk_kn_nm.hpp" +#include "device_gemm_xdlops_mk_nk_nm.hpp" +#include "device_gemm_xdlops_km_kn_nm.hpp" +#include "device_gemm_xdlops_km_nk_nm.hpp" #define USE_GEMM_XDL_MK_KN_MN 1 #define USE_GEMM_XDL_MK_NK_MN 1 #define USE_GEMM_XDL_KM_KN_MN 1 #define USE_GEMM_XDL_KM_NK_MN 1 +#define USE_GEMM_XDL_MK_KN_NM 0 +#define USE_GEMM_XDL_MK_NK_NM 0 +#define USE_GEMM_XDL_KM_KN_NM 0 +#define USE_GEMM_XDL_KM_NK_NM 0 enum GemmAlgo { @@ -28,21 +37,21 @@ enum GemmAlgo Xdl_MK_NK_MN, // 1 Xdl_KM_KN_MN, // 2 Xdl_KM_NK_MN, // 3 + Xdl_MK_KN_NM, // 4 + Xdl_MK_NK_NM, // 5 + Xdl_KM_KN_NM, // 6 + Xdl_KM_NK_NM, // 7 }; int main(int argc, char* argv[]) { using namespace ck; - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - - // dynamic mode - if(argc != 10) + if(argc != 12) { printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: M, N, K\n"); + printf("debug_driver_gemm_xdlops_v2r3::M01, debug_driver_gemm_xdlops_v2r3::N01\n"); exit(1); } @@ -57,6 +66,9 @@ int main(int argc, char* argv[]) const index_t N = std::stoi(argv[8]); const index_t K = std::stoi(argv[9]); + debug::debug_driver_gemm_xdlops_v2r3::M01 = std::stoi(argv[10]); + debug::debug_driver_gemm_xdlops_v2r3::N01 = std::stoi(argv[11]); + #if 0 using ab_data_t = float; using acc_data_t = float; @@ -74,69 +86,44 @@ int main(int argc, char* argv[]) std::vector a_lengths_host(2), b_lengths_host(2), c_lengths_host(2); std::vector a_strides_host(2), b_strides_host(2), c_strides_host(2); - if(layout == GemmMatrixLayout::MK_KN_MN) + // A + if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::MK_NK_MN || + layout == GemmMatrixLayout::MK_KN_NM || layout == GemmMatrixLayout::MK_NK_NM) { a_lengths_host[0] = static_cast(M); a_lengths_host[1] = static_cast(K); a_strides_host[0] = static_cast(K); a_strides_host[1] = static_cast(1); - - b_lengths_host[0] = static_cast(K); - b_lengths_host[1] = static_cast(N); - b_strides_host[0] = static_cast(N); - b_strides_host[1] = static_cast(1); - - c_lengths_host[0] = 
static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); } - else if(layout == GemmMatrixLayout::MK_NK_MN) - { - a_lengths_host[0] = static_cast(M); - a_lengths_host[1] = static_cast(K); - a_strides_host[0] = static_cast(K); - a_strides_host[1] = static_cast(1); - - b_lengths_host[0] = static_cast(N); - b_lengths_host[1] = static_cast(K); - b_strides_host[0] = static_cast(K); - b_strides_host[1] = static_cast(1); - - c_lengths_host[0] = static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); - } - else if(layout == GemmMatrixLayout::KM_KN_MN) + else { a_lengths_host[0] = static_cast(K); a_lengths_host[1] = static_cast(M); a_strides_host[0] = static_cast(M); a_strides_host[1] = static_cast(1); - - b_lengths_host[0] = static_cast(K); - b_lengths_host[1] = static_cast(N); - b_strides_host[0] = static_cast(N); - b_strides_host[1] = static_cast(1); - - c_lengths_host[0] = static_cast(M); - c_lengths_host[1] = static_cast(N); - c_strides_host[0] = static_cast(N); - c_strides_host[1] = static_cast(1); } - else if(layout == GemmMatrixLayout::KM_NK_MN) - { - a_lengths_host[0] = static_cast(K); - a_lengths_host[1] = static_cast(M); - a_strides_host[0] = static_cast(M); - a_strides_host[1] = static_cast(1); + // B + if(layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN || + layout == GemmMatrixLayout::MK_NK_NM || layout == GemmMatrixLayout::KM_NK_NM) + { b_lengths_host[0] = static_cast(N); b_lengths_host[1] = static_cast(K); b_strides_host[0] = static_cast(K); b_strides_host[1] = static_cast(1); + } + else + { + b_lengths_host[0] = static_cast(K); + b_lengths_host[1] = static_cast(N); + b_strides_host[0] = static_cast(N); + b_strides_host[1] = static_cast(1); + } + // C + if(layout == GemmMatrixLayout::MK_KN_MN || layout == GemmMatrixLayout::KM_KN_MN || + layout == GemmMatrixLayout::MK_NK_MN || layout == GemmMatrixLayout::KM_NK_MN) + { c_lengths_host[0] = static_cast(M); c_lengths_host[1] = static_cast(N); c_strides_host[0] = static_cast(N); @@ -144,7 +131,10 @@ int main(int argc, char* argv[]) } else { - std::runtime_error("wrong! 
not implemented"); + c_lengths_host[0] = static_cast(N); + c_lengths_host[1] = static_cast(M); + c_strides_host[0] = static_cast(M); + c_strides_host[1] = static_cast(1); } Tensor a(a_lengths_host, a_strides_host); @@ -185,38 +175,6 @@ int main(int argc, char* argv[]) b.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}, num_thread); } - auto f_make_for_device_mk_kn_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - - auto f_make_for_device_mk_nk_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(K, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - - auto f_make_for_device_km_kn_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(N, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - - auto f_make_for_device_km_nk_mn = [&]() { - const auto a_desc = make_naive_tensor_descriptor(make_tuple(K, M), make_tuple(M, I1)); - const auto b_desc = make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(K, I1)); - const auto c_desc = make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(N, I1)); - - return make_tuple(a_desc, b_desc, c_desc); - }; - #if USE_GEMM_XDL_MK_KN_MN if(algo == GemmAlgo::Xdl_MK_KN_MN) { @@ -225,10 +183,7 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_mk_kn_mn(); - - device_gemm_xdlops_mk_kn_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + device_gemm_xdlops_mk_kn_mn(a, b, c_device, nrepeat); } #endif @@ -240,10 +195,7 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_mk_nk_mn(); - - device_gemm_xdlops_mk_nk_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + device_gemm_xdlops_mk_nk_mn(a, b, c_device, nrepeat); } #endif @@ -255,10 +207,7 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_km_kn_mn(); - - device_gemm_xdlops_km_kn_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); + device_gemm_xdlops_km_kn_mn(a, b, c_device, nrepeat); } #endif @@ -270,10 +219,55 @@ int main(int argc, char* argv[]) throw std::runtime_error("wrong! layout"); } - const auto descs = f_make_for_device_km_nk_mn(); + device_gemm_xdlops_km_nk_mn(a, b, c_device, nrepeat); + } +#endif - device_gemm_xdlops_km_nk_mn( - descs[I0], descs[I1], descs[I2], a, b, c_device, nrepeat); +#if USE_GEMM_XDL_MK_KN_NM + if(algo == GemmAlgo::Xdl_MK_KN_NM) + { + if(layout != GemmMatrixLayout::MK_KN_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_mk_kn_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_MK_NK_NM + if(algo == GemmAlgo::Xdl_MK_NK_NM) + { + if(layout != GemmMatrixLayout::MK_NK_NM) + { + throw std::runtime_error("wrong! 
layout"); + } + + device_gemm_xdlops_mk_nk_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_KN_NM + if(algo == GemmAlgo::Xdl_KM_KN_NM) + { + if(layout != GemmMatrixLayout::KM_KN_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_kn_nm(a, b, c_device, nrepeat); + } +#endif + +#if USE_GEMM_XDL_KM_NK_NM + if(algo == GemmAlgo::Xdl_KM_NK_NM) + { + if(layout != GemmMatrixLayout::KM_NK_NM) + { + throw std::runtime_error("wrong! layout"); + } + + device_gemm_xdlops_km_nk_nm(a, b, c_device, nrepeat); } #endif diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp index e2cba94100..9b66f24f7a 100644 --- a/host/host_tensor/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -2,6 +2,8 @@ #define DEVICE_HPP #include +#include +#include #include "hip/hip_runtime.h" #include "hip/hip_fp16.h" @@ -74,6 +76,8 @@ float launch_and_time_kernel( timer.End(); + // std::this_thread::sleep_for (std::chrono::microseconds(10)); + return timer.GetElapsedTime() / nrepeat; } diff --git a/host/host_tensor/include/gemm_common.hpp b/host/host_tensor/include/gemm_common.hpp index f0f35a78b9..f6c0d6f930 100644 --- a/host/host_tensor/include/gemm_common.hpp +++ b/host/host_tensor/include/gemm_common.hpp @@ -7,6 +7,10 @@ enum GemmMatrixLayout MK_NK_MN, // 1 KM_KN_MN, // 2 KM_NK_MN, // 3 + MK_KN_NM, // 4 + MK_NK_NM, // 5 + KM_KN_NM, // 6 + KM_NK_NM, // 7 }; #endif diff --git a/host/host_tensor/include/host_gemm.hpp b/host/host_tensor/include/host_gemm.hpp index 97cf245054..c582a34258 100644 --- a/host/host_tensor/include/host_gemm.hpp +++ b/host/host_tensor/include/host_gemm.hpp @@ -80,6 +80,78 @@ void host_gemm(const Tensor& a, make_ParallelTensorFunctor(f_km_nk_mn, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( std::thread::hardware_concurrency()); } + else if(layout == GemmMatrixLayout::MK_KN_NM) + { + auto f_mk_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::MK_NK_NM) + { + auto f_mk_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[1]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(m, k)) * static_cast(b(n, k)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_mk_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_KN_NM) + { + auto f_km_kn_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(k, n)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_kn_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } + else if(layout == GemmMatrixLayout::KM_NK_NM) + { + auto f_km_nk_nm = [&](auto n, auto m) { + const int K = a.mDesc.GetLengths()[0]; + + double v = 0; + + for(int k = 0; k < K; ++k) + { + v += static_cast(a(k, m)) * static_cast(b(n, k)); + } + + c(n, m) = v; + }; + + make_ParallelTensorFunctor(f_km_nk_nm, c.mDesc.GetLengths()[0], c.mDesc.GetLengths()[1])( + std::thread::hardware_concurrency()); + } else { throw std::runtime_error("wrong! 
not supported layout"); diff --git a/script/docker-rocm4.3.1.sh b/script/docker-rocm4.3.1.sh new file mode 100755 index 0000000000..48cb675b69 --- /dev/null +++ b/script/docker-rocm4.3.1.sh @@ -0,0 +1,14 @@ +WORKSPACE=$1 +echo "workspace: " $WORKSPACE + +docker run \ +-it \ +--rm \ +--privileged \ +--group-add sudo \ +-w /root/workspace \ +-v $WORKSPACE:/root/workspace \ +rocm/tensorflow:rocm4.3.1-tf2.6-dev \ +/bin/bash + +#--network host \ diff --git a/script/run.sh b/script/run.sh index 3b383fcf3a..1ff56b2295 100755 --- a/script/run.sh +++ b/script/run.sh @@ -4,24 +4,12 @@ export ROCR_VISIBLE_DEVICE=0 export GPU_DEVICE_ORDINAL=0 -## Boost - export LD_LIBRARY_PATH=/usr/local/lib:$LD_LIBRARY_PATH - -## Compiling -#export OLC_DEBUG_HIP_VERBOSE=1 -#export OLC_DEBUG_HIP_DUMP=1 -#export OLC_DEBUG_SAVE_TEMP_DIR=1 - -#rm -rf /root/_hip_binary_kernels_/ -#rm -rf /tmp/olCompile* - -#make -j conv_fwd_driver_offline + make -j conv_fwd_driver_offline #make -j conv_bwd_driver_offline #make -j conv_wrw_driver_offline -#make -j conv_fwd_driver_online - - make -j gemm_driver_offline +#make -j gemm_driver_offline +DRIVER="./host/driver_offline/conv_fwd_driver_offline" LAYOUT=$1 ALGO=$2 VERIFY=$3 @@ -29,30 +17,121 @@ INIT=$4 LOG=$5 REPEAT=$6 -################################################ layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 +#M01=$7 +#N01=$8 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + KBATCH=$7 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 -#./host/driver_offline/conv_fwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 7 17 17 1 1 1 1 0 3 0 3 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG 
$REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 -#./host/driver_offline/conv_bwd_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 192 3 3 35 35 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 -#./host/driver_offline/conv_wrw_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 -#./host/driver_online/conv_fwd_driver_online $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 192 3 3 71 71 2 2 1 1 1 1 1 1 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 -################################################ layout algo verify init log repeat M___ N___ K___ -#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 -#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 - ./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 -#./host/driver_offline/gemm_driver_offline $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 128 3 3 14 14 1 1 1 1 1 1 1 1 + +######### layout algo verify init log repeat M___ N___ K___ +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 960 1024 1024 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 1920 2048 2048 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 3840 4096 4096 $M01 $N01 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 7680 8192 8192 $M01 $N01 + +# Resnet50 +######### layout algo verify init log repeat N__ K___ C___ Y X Hi_ Wi__ Strides Dilations LeftPads RightPads + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 28 28 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 128 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 128 3 3 58 58 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 14 14 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 256 3 3 30 30 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 256 1 1 56 56 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 256 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 16 16 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 128 512 1 1 28 28 1 1 1 1 0 0 0 0 
+ $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 512 1 1 28 28 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 512 512 3 3 7 7 1 1 1 1 1 1 1 1 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 256 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 1 1 56 56 1 1 1 1 0 0 0 0 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 256 64 64 3 3 56 56 1 1 1 1 1 1 1 1 + +# 256x128x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH + + + +# 128x128x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 1024 1 1 14 14 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 1024 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 1024 1 1 14 14 1 1 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 28 28 1 1 1 1 1 1 1 1 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 128 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 128 3 3 58 58 2 2 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 2048 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 256 1 1 14 14 1 1 1 1 0 0 0 0 56 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 14 14 1 1 1 1 1 1 1 1 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 256 3 3 30 30 2 2 1 1 0 0 0 0 
28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 256 1 1 56 56 1 1 1 1 0 0 0 0 448 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 256 1 1 56 56 2 2 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 16 16 2 2 1 1 0 0 0 0 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 1024 512 1 1 28 28 2 2 1 1 0 0 0 0 28 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 128 512 1 1 28 28 1 1 1 1 0 0 0 0 224 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 512 1 1 28 28 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 2048 512 1 1 7 7 1 1 1 1 0 0 0 0 14 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 512 512 3 3 7 7 1 1 1 1 1 1 1 1 7 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 $KBATCH + + +# 128x64x32 c64 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 + +# 64x128x32 c64 + $DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 $KBATCH + +# 64x64x32 c32 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 256 1 1 56 56 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 256 64 1 1 56 56 1 1 1 1 0 0 0 0 112 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 1 1 56 56 1 1 1 1 0 0 0 0 448 +#$DRIVER $LAYOUT $ALGO $VERIFY $INIT $LOG $REPEAT 128 64 64 3 3 56 56 1 1 1 1 1 1 1 1 448 From b2dc55f82c635dc9a0a512ca3f476e1d825b0a8c Mon Sep 17 00:00:00 2001 From: Qianfeng Date: Thu, 7 Oct 2021 03:43:17 +0800 Subject: [PATCH 12/15] [MIOpen Downstream] Fix Reduction Kernel (#34) * Tiny fix in using data type template parameters in blockwise and direct_threadwise kernel * Fix with regard to implementing GetZeroVal() in both kernel and host * Avoid convert to compType from dstDataType before writting the output value * Add half_t support to NumericLimits and make constexpr GetZeroVal() of binary operator * Add CONSTANT decorator for descriptor read buffer * Use get_thread_local_1d_id() for thread local Id * Rename GetZeroVal() to GetReductionZeroVal() in the kernels * Remove constexpr from initialized zeroVal and tiny fix in reduction_operator.hpp * Occasional tiny simplification and update in the kernel files * Update to re-order tensor dimensions on the host, split second_call kernel wrapper files and simplify reduce_all kernel wrappers * Update to remove OpenCL tidy checking failures * Update for better readability * Remove unused codes and not-needed template parameters in the kernel wrappers Co-authored-by: Chao Liu --- ...ridwise_generic_2d_reduction_blockwise.hpp | 38 ++- ...generic_2d_reduction_direct_threadwise.hpp | 40 ++-- ...e_generic_2d_reduction_direct_warpwise.hpp | 36 ++- ...idwise_generic_2d_reduction_multiblock.hpp | 4 +- .../reduction_functions_blockwise.hpp | 4 +- .../reduction_functions_warpwise.hpp | 12 +- .../include/utility/data_type.hpp | 27 ++- .../include/utility/reduction_common.hpp | 59 +---- .../include/utility/reduction_enums.hpp | 66 ++++++ .../include/utility/reduction_operator.hpp | 65 +++-- ...n_first_call_blockwise_reduce_all_dims.cpp | 88 ++----- ...rst_call_blockwise_reduce_partial_dims.cpp | 39 +-- 
..._first_call_multiblock_reduce_all_dims.cpp | 89 ++----- ...st_call_multiblock_reduce_partial_dims.cpp | 41 ++-- ..._first_call_threadwise_reduce_all_dims.cpp | 90 ++----- ...st_call_threadwise_reduce_partial_dims.cpp | 41 ++-- ...on_first_call_warpwise_reduce_all_dims.cpp | 91 ++----- ...irst_call_warpwise_reduce_partial_dims.cpp | 41 ++-- ..._second_call_blockwise_reduce_all_dims.cpp | 205 ++++++++++++++++ ...nd_call_blockwise_reduce_partial_dims.cpp} | 43 +--- ...second_call_threadwise_reduce_all_dims.cpp | 222 ++++++++++++++++++ ...d_call_threadwise_reduce_partial_dims.cpp} | 45 +--- ...n_second_call_warpwise_reduce_all_dims.cpp | 221 +++++++++++++++++ ...ond_call_warpwise_reduce_partial_dims.cpp} | 45 +--- 24 files changed, 1031 insertions(+), 621 deletions(-) create mode 100644 composable_kernel/include/utility/reduction_enums.hpp create mode 100644 composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp rename composable_kernel/src/kernel_wrapper/{gridwise_generic_reduction_second_call_blockwise.cpp => gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp} (87%) create mode 100644 composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp rename composable_kernel/src/kernel_wrapper/{gridwise_generic_reduction_second_call_threadwise.cpp => gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp} (87%) create mode 100644 composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp rename composable_kernel/src/kernel_wrapper/{gridwise_generic_reduction_second_call_warpwise.cpp => gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp} (87%) diff --git a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp index 20075526b2..c635da57f4 100644 --- a/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_generic_2d_reduction_blockwise.hpp @@ -92,7 +92,7 @@ struct GridwiseReduction_xy_to_x_blockwise // LDS __shared__ compType p_in_block_buffer[BlockBufferSize]; - auto zeroVal = opReduce::GetZeroVal(); + const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise threadwise_dst_load.Run( dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -281,7 +285,7 @@ struct GridwiseReduction_xy_to_x_blockwise ThreadClusterLengths, Sequence<0, 1>, srcDataType, - dstDataType, + compType, src2dDescType, decltype(in_block_desc), Sequence<0, 1>, @@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= 
type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3(ws_values_global, @@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3 @@ -200,7 +204,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise { (void)ws_indices_global; - const auto zeroVal = opReduce::GetZeroVal(); + const auto zeroVal = opReduce::GetReductionZeroVal(); const auto src_global_buf = make_dynamic_buffer( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -232,7 +236,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); auto threadwise_src_load = ThreadwiseTensorSliceTransfer_v2{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3(ws_values_global, @@ -377,7 +385,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise index_t thread_global_1d_id = get_block_1d_id() * BlockSize + get_thread_local_1d_id(); auto threadwise_src_val_load = ThreadwiseTensorSliceTransfer_v2{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise threadwise_dst_load.Run( 
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf(I0) * beta); + dstValue_buf(I0) += priorDstValue_buf(I0) * beta; } auto threadwise_dst_store = - ThreadwiseTensorSliceTransfer_v1r3( p_src_global, src2dDesc.GetElementSpaceSize(), type_convert{}(zeroVal)); @@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3(ws_values_global, @@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise if(!float_equal_one{}(alpha)) accuValue_buf(I0) *= type_convert{}(alpha); + StaticBuffer dstValue_buf; + + dstValue_buf(I0) = type_convert{}(accuValue_buf[I0]); + if(!float_equal_zero{}(beta)) { auto threadwise_dst_load = @@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise make_tuple(I0), priorDstValue_buf); - accuValue_buf(I0) += type_convert{}(priorDstValue_buf[I0] * beta); + dstValue_buf(I0) += priorDstValue_buf[I0] * beta; } auto threadwise_dst_val_store = - ThreadwiseTensorSliceTransfer_v1r3{}( [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); @@ -84,7 +84,7 @@ struct WarpReduce // since for fp16, built-in shuffling functions is not provided by HIP __device__ static void ReduceImpl2(const BufferType& thread_buffer, compType& accuData) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); static_for<0, ThreadBufferLen, 1>{}( [&](auto I) { binop::calculate(lAccuData, thread_buffer[I]); }); @@ -138,7 +138,7 @@ struct WarpReduce int& accuIndex, int indexStart) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; index_t thread_inwarp_id = get_thread_local_1d_id() % warpSize; @@ -170,7 +170,7 @@ struct WarpReduce int& accuIndex, int indexStart) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; index_t thread_id = get_thread_local_1d_id(); index_t warpId = thread_id / warpSize; @@ -278,7 +278,7 @@ struct WarpReduceWithIndicesInput compType& accuData, int& accuIndex) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; static_for<0, ThreadBufferLen, 1>{}([&](auto I) { @@ -307,7 +307,7 @@ struct WarpReduceWithIndicesInput compType& accuData, int& accuIndex) { - compType lAccuData = opReduce::GetZeroVal(); + compType lAccuData = opReduce::GetReductionZeroVal(); int lAccuIndex = 0; index_t thread_id = get_thread_local_1d_id(); index_t warpId = thread_id / warpSize; diff --git a/composable_kernel/include/utility/data_type.hpp b/composable_kernel/include/utility/data_type.hpp index bfaac8a939..07eceb84cf 100644 --- a/composable_kernel/include/utility/data_type.hpp +++ b/composable_kernel/include/utility/data_type.hpp @@ -1008,20 +1008,27 @@ struct inner_product_with_conversion }; template -struct NumericLimits; +struct NumericLimits +{ + __host__ __device__ static constexpr T Min() { return 
std::numeric_limits::min(); } + + __host__ __device__ static constexpr T Max() { return std::numeric_limits::max(); } + + __host__ __device__ static constexpr T Lowest() { return std::numeric_limits::lowest(); } +}; template <> -struct NumericLimits +struct NumericLimits { - __host__ __device__ static constexpr int32_t Min() - { - return std::numeric_limits::min(); - } + static constexpr unsigned short binary_min = 0x0400; + static constexpr unsigned short binary_max = 0x7BFF; + static constexpr unsigned short binary_lowest = 0xFBFF; - __host__ __device__ static constexpr int32_t Max() - { - return std::numeric_limits::max(); - } + __host__ __device__ static constexpr half_t Min() { return as_type(binary_min); } + + __host__ __device__ static constexpr half_t Max() { return as_type(binary_max); } + + __host__ __device__ static constexpr half_t Lowest() { return as_type(binary_lowest); } }; } // namespace ck diff --git a/composable_kernel/include/utility/reduction_common.hpp b/composable_kernel/include/utility/reduction_common.hpp index 139a18c2a4..ff574c315c 100644 --- a/composable_kernel/include/utility/reduction_common.hpp +++ b/composable_kernel/include/utility/reduction_common.hpp @@ -26,76 +26,25 @@ #ifndef CK_REDUCTION_COMMON_HPP #define CK_REDUCTION_COMMON_HPP -// this enumerate should be synchronized with include/miopen/reduce_common.hpp +#include "reduction_enums.hpp" + namespace ck { -enum class ReductionMethod_t -{ - DirectThreadWise = 1, - DirectWarpWise = 2, - BlockWise = 3, - MultiBlock = 4 -}; // end of namespace ck - -enum class ReduceTensorOp_t -{ - ADD = 0, - MUL = 1, - MIN = 2, - MAX = 3, - AMAX = 4, - AVG = 5, - NORM1 = 6, - NORM2 = 7, - // MUL_NO_ZEROS = 8, -}; - -enum class NanPropagation_t -{ - NOT_PROPAGATE_NAN = 0, - PROPAGATE_NAN = 1, -}; - -enum class ReduceTensorIndices_t -{ - NO_INDICES = 0, - FLATTENED_INDICES = 1, -}; - -enum class IndicesType_t -{ - INDICES_32BIT = 0, - INDICES_64BIT = 1, - INDICES_16BIT = 2, - INDICES_8BIT = 3, -}; struct float_equal_one { - template - __device__ static inline bool apply(T x) - { - return x <= type_convert{}(1.0f) and x >= type_convert{}(1.0f); - } - template __device__ inline bool operator()(T x) { - return (float_equal_one::apply(x)); + return x <= static_cast(1.0f) and x >= static_cast(1.0f); }; }; struct float_equal_zero { - template - __device__ static inline bool apply(T x) - { - return x <= type_convert{}(0.0f) and x >= type_convert{}(0.0f); - } - template __device__ inline bool operator()(T x) { - return (float_equal_zero::apply(x)); + return x <= static_cast(0.0f) and x >= static_cast(0.0f); }; }; diff --git a/composable_kernel/include/utility/reduction_enums.hpp b/composable_kernel/include/utility/reduction_enums.hpp new file mode 100644 index 0000000000..e97108179e --- /dev/null +++ b/composable_kernel/include/utility/reduction_enums.hpp @@ -0,0 +1,66 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2020 Advanced Micro Devices, Inc. 
+ * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#ifndef CK_REDUCTION_ENUMS_HPP +#define CK_REDUCTION_ENUMS_HPP + +namespace ck { + +enum class ReduceTensorOp_t +{ + ADD = 0, + MUL = 1, + MIN = 2, + MAX = 3, + AMAX = 4, + AVG = 5, + NORM1 = 6, + NORM2 = 7, + // MUL_NO_ZEROS = 8, +}; + +enum class NanPropagation_t +{ + NOT_PROPAGATE_NAN = 0, + PROPAGATE_NAN = 1, +}; + +enum class ReduceTensorIndices_t +{ + NO_INDICES = 0, + FLATTENED_INDICES = 1, +}; + +enum class IndicesType_t +{ + INDICES_32BIT = 0, + INDICES_64BIT = 1, + INDICES_16BIT = 2, + INDICES_8BIT = 3, +}; + +}; // end of namespace ck + +#endif diff --git a/composable_kernel/include/utility/reduction_operator.hpp b/composable_kernel/include/utility/reduction_operator.hpp index 269671a400..c0afbec869 100644 --- a/composable_kernel/include/utility/reduction_operator.hpp +++ b/composable_kernel/include/utility/reduction_operator.hpp @@ -35,10 +35,12 @@ namespace reduce { // Every binary operator used in reduction is represented by a templated functor class. Each functor // class must provide at least // three members: -// 1) GetZeroVal() -- the interface to return the "identity element" for the binary operator, -// "identity element" is the unique +// 1) GetReductionZeroVal() -- the interface to return the "identity element" for the binary +// operator, "identity element" is the unique // element in the algebraic space that doesn't affect the value of other elements -// when operated with any of them. +// when operated against them, and the concept is similar to zero vector in +// vector space +// (http://pages.cs.wisc.edu/~matthewb/pages/notes/pdf/linearalgebra/VectorSpaces.pdf). // 2) indexable -- boolean value indicating whether indices of the operated elements could be // recorded. Usually, Min/Max operator could // need to record the indices of elements. 
For operator like Add/Mul, no need to @@ -58,7 +60,7 @@ struct Add { using dataType = T; - __device__ static T GetZeroVal() { return type_convert{}(0.0f); }; + __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } @@ -70,7 +72,7 @@ struct Mul { using dataType = T; - __device__ static T GetZeroVal() { return type_convert{}(1.0f); }; + __device__ static constexpr T GetReductionZeroVal() { return static_cast(1.0f); }; __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; } @@ -82,7 +84,7 @@ struct Max { using dataType = T; - __device__ static T GetZeroVal() { return std::numeric_limits::min(); }; + __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Lowest(); }; __device__ inline constexpr void operator()(T& a, T b) const { @@ -107,7 +109,7 @@ struct Min { using dataType = T; - __device__ static T GetZeroVal() { return std::numeric_limits::max(); }; + __device__ static constexpr T GetReductionZeroVal() { return NumericLimits::Max(); }; __device__ inline constexpr void operator()(T& a, T b) const { @@ -127,16 +129,29 @@ struct Min static constexpr bool indexable = true; }; -template <> -__device__ half_t Max::GetZeroVal() +template +struct AMax { - return type_convert{}(std::numeric_limits::min()); -}; + using dataType = T; -template <> -__device__ half_t Min::GetZeroVal() -{ - return type_convert{}(std::numeric_limits::max()); + __device__ static constexpr T GetReductionZeroVal() { return static_cast(0.0f); }; + + __device__ inline constexpr void operator()(T& a, T b) const + { + if(a < b) + a = b; + } + + __device__ inline constexpr void operator()(T& a, T b, bool& changed) const + { + if(a < b) + { + a = b; + changed = true; + } + } + + static constexpr bool indexable = true; }; // Unary operators are usually called element-wisely before the reduction is executed on the @@ -268,7 +283,7 @@ struct unary_sqrt // The templated struct reduce_binary_operator maps the enum Ids of binary operators to their // respective functor classes. -// The "GetZeroVal()" interface and boolean member "indexable" are also provided in +// The "GetReductionZeroVal()" interface and boolean member "indexable" are also provided in // reduce_binary_operactor for // easier checking by the upper-layer codes in the kernels. 
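(Editorial aside, not part of the patch: a minimal host-side sketch of how a kernel consumes the functor interface described in the comment above. The MaxLike functor and fold() helper below are hypothetical stand-ins modeled on the reduce::Max functor in this hunk; the accumulator starts at the identity element returned by GetReductionZeroVal(), every input is folded in through operator(), and for an indexable operator the three-argument overload reports when the accumulated value, and hence its index, changed.)

#include <vector>
#include <limits>

// Hypothetical stand-in mirroring the interface this file defines:
// GetReductionZeroVal(), operator()(a, b), operator()(a, b, changed), indexable.
template <typename T>
struct MaxLike
{
    static constexpr T GetReductionZeroVal() { return std::numeric_limits<T>::lowest(); }

    void operator()(T& a, T b) const
    {
        if(a < b)
            a = b;
    }

    void operator()(T& a, T b, bool& changed) const
    {
        if(a < b)
        {
            a       = b;
            changed = true;
        }
    }

    static constexpr bool indexable = true;
};

// Kernel-style consumption of the interface (for an indexable operator):
// start from the identity element, which leaves the accumulator unchanged
// when folded in, accumulate every value, and record the index of the
// last change.
template <typename Op, typename T>
T fold(const std::vector<T>& values, int& arg)
{
    T acc = Op::GetReductionZeroVal();
    arg   = 0;
    for(int i = 0; i < static_cast<int>(values.size()); ++i)
    {
        bool changed = false;
        Op{}(acc, values[i], changed);
        if(changed)
            arg = i;
    }
    return acc;
}

(End of aside; the next hunk continues in reduction_operator.hpp.)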
@@ -281,8 +296,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; @@ -292,8 +305,6 @@ struct reduce_binary_operator using opType = reduce::Mul; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Mul::GetZeroVal(); }; - static constexpr bool indexable = reduce::Mul::indexable; }; @@ -303,8 +314,6 @@ struct reduce_binary_operator using opType = reduce::Min; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Min::GetZeroVal(); }; - static constexpr bool indexable = reduce::Min::indexable; }; @@ -314,19 +323,15 @@ struct reduce_binary_operator using opType = reduce::Max; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Max::GetZeroVal(); }; - static constexpr bool indexable = reduce::Max::indexable; }; template struct reduce_binary_operator { - using opType = reduce::Max; + using opType = reduce::AMax; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Max::GetZeroVal(); }; - static constexpr bool indexable = reduce::Max::indexable; }; @@ -336,8 +341,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; @@ -347,8 +350,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; @@ -358,8 +359,6 @@ struct reduce_binary_operator using opType = reduce::Add; using dataType = T; - __device__ static T GetZeroVal() { return reduce::Add::GetZeroVal(); }; - static constexpr bool indexable = reduce::Add::indexable; }; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp index e16010dee1..ca6b415910 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_all_dims.cpp @@ -43,9 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr 
bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)GridSize; @@ -132,18 +107,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -157,14 +128,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; @@ -179,30 +144,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected referrence type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = 
make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -217,12 +180,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -235,25 +192,22 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_34 = - typename get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; -template +template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) { if constexpr(need_padding) @@ -277,15 +231,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp index cba7ffe295..a3daeaf163 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_blockwise_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using 
invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensins are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -133,14 +122,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@ -179,16 +166,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; }; @@ -278,15 +265,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = 
cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp index 34b877027c..81899dfb02 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_all_dims.cpp @@ -43,10 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)GridSize; @@ -132,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = 
make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -157,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; const index_t reduceSizePerBlock = @@ -181,30 +145,28 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected reference type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -219,12 +181,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -237,23 +193,20 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_34 = - typename 
get_ref_desc_types::refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -279,16 +232,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_dst_global; (void)indices_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp index 9c7318dc15..0e578f4d1d 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_multiblock_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int 
outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@ -180,16 +167,16 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; }; @@ -279,16 +266,16 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_dst_global; (void)indices_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp index 8e67d1faa1..e63a1254e4 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_all_dims.cpp @@ -43,9 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = 
static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)BlkGroupSize; @@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = GredThreadBufferLength; @@ -178,12 +143,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -191,31 +156,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, { const auto dstPad = GridSize * BlockSize - invariantLen; auto 
dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, + transform_tensor_descriptor(dstDesc, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected reference type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); - static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -230,12 +193,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -248,23 +205,20 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -290,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 
2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp index fdbcda64ba..698f740058 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_threadwise_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@ 
-178,12 +165,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -195,12 +182,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; } }; @@ -291,15 +278,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp index 8aa1376c3a..4a607372e9 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_all_dims.cpp @@ -43,9 +43,6 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; -constexpr index_t dstDims = CK_PARAM_OUT_DIMS; - -using toReduceDims = Sequence; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -58,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge, toReduceDims>::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -110,18 +97,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int 
outLength4, - int outLength5, - int outStride0, - int outStride1, - int outStride2, - int outStride3, - int outStride4, - int outStride5, void* __restrict__ ws_global) { (void)BlkGroupSize; @@ -131,18 +106,14 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; - const int dstStrides[6] = { - outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); - const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); - const auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); const auto one_dim_srcDesc = transform_tensor_descriptor( srcDesc, @@ -156,14 +127,8 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - auto dst1dDesc = transform_tensor_descriptor( - dstDesc, - make_tuple(make_merge_transform(tupleDstLengths)), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - - const auto invariantLen = src2dDesc.GetLength(Number<0>{}); - const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); + constexpr int invariantLen = 1; + const auto toReduceLen = src2dDesc.GetLength(Number<1>{}); constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; @@ -179,12 +144,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -192,31 +157,29 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, { const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; auto dst1dDesc_2 = - transform_tensor_descriptor(dst1dDesc, + transform_tensor_descriptor(dstDesc, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) - *static_cast(p_dst1dDesc) = dst1dDesc; + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_srcLengths = typename uniform_sequence_gen::type{}; - static constexpr auto ref_dstLengths = typename uniform_sequence_gen::type{}; // don't have to use accurate strides to get an expected reference type static constexpr auto ref_srcDesc = make_naive_tensor_descriptor( make_tuple_from_seq(ref_srcLengths), make_tuple_from_seq(ref_srcLengths)); 
- static constexpr auto ref_dstDesc = make_naive_tensor_descriptor( - make_tuple_from_seq(ref_dstLengths), make_tuple_from_seq(ref_dstLengths)); + static constexpr auto ref_dstDesc = make_naive_tensor_descriptor(make_tuple(1), make_tuple(1)); static constexpr auto ref_one_dim_srcDesc = transform_tensor_descriptor( ref_srcDesc, @@ -231,12 +194,6 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}), make_tuple(Sequence<0, 1>{})); - static constexpr auto ref_dst1dDesc = transform_tensor_descriptor( - ref_dstDesc, - make_tuple(make_merge_transform(make_tuple_from_seq(ref_dstLengths))), - make_tuple(typename arithmetic_sequence_gen<0, dstDims, 1>::type{}), - make_tuple(Sequence<0>{})); - static constexpr auto ref_invariantLen = ref_src2dDesc.GetLength(Number<0>{}); static constexpr auto ref_toReduceLen = ref_src2dDesc.GetLength(Number<1>{}); @@ -249,23 +206,19 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}, Sequence<1>{}))); using refType_dst1dDesc_padded = - decltype(transform_tensor_descriptor(ref_dst1dDesc, + decltype(transform_tensor_descriptor(ref_dstDesc, make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{}))); using refType_src2dDesc = decltype(ref_src2dDesc); - using refType_dst1dDesc = decltype(ref_dst1dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; -using refType_src2dDesc_padded_12 = - typename get_ref_desc_types::refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types::refType_dst1dDesc_padded; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -291,15 +244,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp index e18d623fe5..a641527900 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_first_call_warpwise_reduce_partial_dims.cpp @@ -45,8 +45,11 @@ constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using 
invariantDims = Sequence; +constexpr index_t num_toReduceDims = CK_PARAM_NUM_TOREDUCE_DIMS; +constexpr index_t num_invariantDims = srcDims - num_toReduceDims; + +using invariantDims = typename arithmetic_sequence_gen<0, num_invariantDims, 1>::type; +using toReduceDims = typename arithmetic_sequence_gen::type; constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 @@ -59,15 +62,7 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); +static_assert(num_invariantDims > 0, "Not all dimensions are reduced for this kernel !!"); constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -111,12 +106,6 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, int inStride3, int inStride4, int inStride5, - int outLength0, - int outLength1, - int outLength2, - int outLength3, - int outLength4, - int outLength5, int outStride0, int outStride1, int outStride2, @@ -132,14 +121,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, const int srcLengths[6] = {inLength0, inLength1, inLength2, inLength3, inLength4, inLength5}; const int srcStrides[6] = {inStride0, inStride1, inStride2, inStride3, inStride4, inStride5}; - const int dstLengths[6] = { - outLength0, outLength1, outLength2, outLength3, outLength4, outLength5}; const int dstStrides[6] = { outStride0, outStride1, outStride2, outStride3, outStride4, outStride5}; const auto tupleSrcLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleSrcStrides = make_tuple_from_array(srcStrides, Number{}); - const auto tupleDstLengths = make_tuple_from_array(dstLengths, Number{}); + const auto tupleDstLengths = make_tuple_from_array(srcLengths, Number{}); const auto tupleDstStrides = make_tuple_from_array(dstStrides, Number{}); const auto srcDesc = make_naive_tensor_descriptor(tupleSrcLengths, tupleSrcStrides); @@ -179,12 +166,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -196,12 +183,12 @@ extern "C" __global__ void gridwise_generic_reduce_1_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; 
} }; @@ -292,15 +279,15 @@ extern "C" __global__ void gridwise_generic_reduce_1(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)BlkGroupSize; (void)ws_buf2_bytes_offset; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp new file mode 100644 index 0000000000..7e9d46612e --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_all_dims.cpp @@ -0,0 +1,205 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "config.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "tensor_descriptor_helper.hpp" +#include "data_type_enum_helper.hpp" +#include "reduction_common.hpp" +#include "gridwise_generic_2d_reduction_blockwise.hpp" + +using namespace ck; + +using srcDataType = + typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; +using dstDataType = + typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; +using compType = + typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable + +constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); +constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 + ? NanPropagation_t::NOT_PROPAGATE_NAN + : NanPropagation_t::PROPAGATE_NAN; +constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 + ? 
ReduceTensorIndices_t::NO_INDICES + : ReduceTensorIndices_t::FLATTENED_INDICES; + +constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); +constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); + +constexpr bool indexable = reduce_binary_operator::indexable; +constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); + +constexpr index_t GredAccessesPerThreadInBlock = CK_PARAM_ACCESSES_PER_THREAD_INBLOCK; // tunable + +extern "C" __global__ void +gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) +{ + (void)GridSize; + + void* p_src2dDesc = ws_global; + void* p_dst1dDesc = static_cast(ws_global) + 2048; + + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); + + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + const index_t invariantLen = dstDesc.GetLength(Number<0>{}); + const index_t toReduceLen = BlkGroupSize; + + auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); + + constexpr auto copySliceLen = BlockSize * GredAccessesPerThreadInBlock; + + if constexpr(src2d_need_padding) + { + const auto srcPad = + ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; + + auto src2dDesc_2 = + transform_tensor_descriptor(src2dDesc, + make_tuple(make_pass_through_transform(invariantLen), + make_pad_transform(toReduceLen, 0, srcPad)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc; + } + + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; +}; + +struct get_ref_desc_types +{ + static constexpr auto ref_tupleDstLengths = make_tuple(8); + static constexpr auto ref_dstDesc = + make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); + + static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); + static constexpr index_t ref_toReduceLen = 8; + + static constexpr auto ref_src2dDesc = + make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); + + using refType_src2dDesc = decltype(ref_src2dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); + + // used by the BlockWise and MultiBlock method + using refType_src2dDesc_padded_34 = decltype( + transform_tensor_descriptor(ref_src2dDesc, + make_tuple(make_pass_through_transform(ref_invariantLen), + make_pad_transform(ref_toReduceLen, 0, 2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}))); + + using refType_dst1dDesc_padded = + decltype(transform_tensor_descriptor(ref_dstDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{}))); +}; + +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc_padded_34 = typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; + +template +static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_src2dDesc)); + else + return (*reinterpret_cast(p_src2dDesc)); +}; + +template +static 
__device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_dst1dDesc)); + else + return (*reinterpret_cast(p_dst1dDesc)); +}; + +extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, + float alpha, + const void* __restrict__ p_src_global, + float beta, + void* __restrict__ p_dst_global, + const void CONSTANT* ws_global, + long ws_buf2_bytes_offset, + void* __restrict__ indices_global) +{ + (void)p_src_global; + + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); + + const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); + const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); + + using gridwise_2d_reduce = GridwiseReduction_xy_to_x_blockwise; + + void* const ws_buf2_global = + ws_buf2_bytes_offset > 0 + ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) + : nullptr; + + constexpr int RunId = need_indices ? 3 : 1; + gridwise_2d_reduce::template Run( + src2dDesc, + dst1dDesc, + origReduceLen, + alpha, + static_cast(ws_buf1_global), + beta, + static_cast(p_dst_global), + static_cast(ws_buf2_global), + static_cast(indices_global)); +}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp similarity index 87% rename from composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise.cpp rename to composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp index b7b58cbb90..3f37d01e21 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_blockwise_reduce_partial_dims.cpp @@ -42,12 +42,8 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable -constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 ? 
NanPropagation_t::NOT_PROPAGATE_NAN @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -152,20 +138,20 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; }; -template +template struct get_ref_desc_types { static constexpr auto ref_tupleDstLengths = @@ -203,16 +189,11 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}))); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_34 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_34; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_34; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -237,15 +218,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_src_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp new file mode 100644 index 0000000000..77841d1312 --- /dev/null +++ 
b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_all_dims.cpp @@ -0,0 +1,222 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "config.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "tensor_descriptor_helper.hpp" +#include "data_type_enum_helper.hpp" +#include "reduction_common.hpp" +#include "gridwise_generic_2d_reduction_direct_threadwise.hpp" + +using namespace ck; + +using srcDataType = + typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; +using dstDataType = + typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; +using compType = + typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable + +using toReduceDims = Sequence; +using invariantDims = Sequence; // this could be empty + +constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); +constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 + ? NanPropagation_t::NOT_PROPAGATE_NAN + : NanPropagation_t::PROPAGATE_NAN; +constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 + ? 
ReduceTensorIndices_t::NO_INDICES + : ReduceTensorIndices_t::FLATTENED_INDICES; + +constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); +constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); + +constexpr bool indexable = reduce_binary_operator::indexable; +constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); + +constexpr index_t GredThreadBufferLength = CK_PARAM_THREAD_BUFFER_LENGTH; // tunable + +extern "C" __global__ void +gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) +{ + (void)BlkGroupSize; + + void* p_src2dDesc = ws_global; + void* p_dst1dDesc = static_cast(ws_global) + 2048; + + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); + + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + const index_t invariantLen = dstDesc.GetLength(Number<0>{}); + const index_t toReduceLen = BlkGroupSize; + + auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); + + constexpr auto copySliceLen = GredThreadBufferLength; + + if constexpr(src2d_need_padding) + { + const auto srcPad1 = GridSize * BlockSize - invariantLen; + const auto srcPad2 = + ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; + auto src2dDesc_2 = + transform_tensor_descriptor(src2dDesc, + make_tuple(make_pad_transform(invariantLen, 0, srcPad1), + make_pad_transform(toReduceLen, 0, srcPad2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc; + } + + if constexpr(dst1d_need_padding) + { + const auto dstPad = GridSize * BlockSize - invariantLen; + auto dst1dDesc_2 = + transform_tensor_descriptor(dstDesc, + make_tuple(make_pad_transform(invariantLen, 0, dstPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dst1dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; + } +}; + +struct get_ref_desc_types +{ + static constexpr auto ref_tupleDstLengths = make_tuple(8); + static constexpr auto ref_dstDesc = + make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); + + static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); + static constexpr index_t ref_toReduceLen = 8; + + static constexpr auto ref_src2dDesc = + make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); + + using refType_src2dDesc = decltype(ref_src2dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); + + // used by the DirectThreadWise and DirectWarpWise method + using refType_src2dDesc_padded_12 = + decltype(transform_tensor_descriptor(ref_src2dDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2), + make_pad_transform(ref_toReduceLen, 0, 2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}))); + + using refType_dst1dDesc_padded = + decltype(transform_tensor_descriptor(ref_dstDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{}))); +}; + +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; +using 
refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; + +template +static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_src2dDesc)); + else + return (*reinterpret_cast(p_src2dDesc)); +}; + +template +static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_dst1dDesc)); + else + return (*reinterpret_cast(p_dst1dDesc)); +}; + +extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, + float alpha, + const void* __restrict__ p_src_global, + float beta, + void* __restrict__ p_dst_global, + const void CONSTANT* ws_global, + long ws_buf2_bytes_offset, + void* __restrict__ indices_global) +{ + (void)p_src_global; + + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); + + const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); + const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); + + using gridwise_2d_reduce = GridwiseReduction_xy_to_x_direct_threadwise; + + void* const ws_buf2_global = + ws_buf2_bytes_offset > 0 + ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) + : nullptr; + + constexpr int RunId = need_indices ? 3 : 1; + gridwise_2d_reduce::template Run( + src2dDesc, + dst1dDesc, + origReduceLen, + alpha, + static_cast(ws_buf1_global), + beta, + static_cast(p_dst_global), + static_cast(ws_buf2_global), + static_cast(indices_global)); +}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp similarity index 87% rename from composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise.cpp rename to composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp index ef88547028..2de461ad0f 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_threadwise_reduce_partial_dims.cpp @@ -42,12 +42,8 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable -constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 ? 
NanPropagation_t::NOT_PROPAGATE_NAN @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -152,12 +138,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -169,17 +155,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_tupleDstLengths = @@ -217,16 +203,11 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}))); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -251,15 +232,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_src_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); diff 
--git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp new file mode 100644 index 0000000000..1ba5e49657 --- /dev/null +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_all_dims.cpp @@ -0,0 +1,221 @@ +/******************************************************************************* + * + * MIT License + * + * Copyright (c) 2021 Advanced Micro Devices, Inc. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + *******************************************************************************/ +#include "config.hpp" +#include "number.hpp" +#include "sequence.hpp" +#include "tensor_descriptor_helper.hpp" +#include "data_type_enum_helper.hpp" +#include "reduction_common.hpp" +#include "gridwise_generic_2d_reduction_direct_warpwise.hpp" + +using namespace ck; + +using srcDataType = + typename get_datatype_from_enum(CK_PARAM_SRC_DATATYPE)>::type; +using dstDataType = + typename get_datatype_from_enum(CK_PARAM_DST_DATATYPE)>::type; +using compType = + typename get_datatype_from_enum(CK_PARAM_REDUCE_COMPTYPE)>::type; + +constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable + +constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); +constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 + ? NanPropagation_t::NOT_PROPAGATE_NAN + : NanPropagation_t::PROPAGATE_NAN; +constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 + ? 
ReduceTensorIndices_t::NO_INDICES + : ReduceTensorIndices_t::FLATTENED_INDICES; + +constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); +constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); + +constexpr bool indexable = reduce_binary_operator::indexable; +constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); + +constexpr index_t GredAccessesPerThreadInWarp = CK_PARAM_ACCESSES_PER_THREAD_INWARP; // tunable + +extern "C" __global__ void +gridwise_generic_reduce_2_prepare(int GridSize, int BlkGroupSize, void* __restrict__ ws_global) +{ + (void)BlkGroupSize; + + void* p_src2dDesc = ws_global; + void* p_dst1dDesc = static_cast(ws_global) + 2048; + + const auto tupleDstLengths = make_tuple(1); + const auto tupleDstStrides = make_tuple(1); + + auto dstDesc = make_naive_tensor_descriptor(tupleDstLengths, tupleDstStrides); + + const index_t invariantLen = dstDesc.GetLength(Number<0>{}); + const index_t toReduceLen = BlkGroupSize; + + auto src2dDesc = make_naive_tensor_descriptor_packed(make_tuple(invariantLen, toReduceLen)); + + constexpr auto copySliceLen = warpSize * GredAccessesPerThreadInWarp; + + if constexpr(src2d_need_padding) + { + const auto srcPad1 = GridSize * BlockSize / warpSize - invariantLen; + const auto srcPad2 = + ((toReduceLen + copySliceLen - 1) / copySliceLen) * copySliceLen - toReduceLen; + + auto src2dDesc_2 = + transform_tensor_descriptor(src2dDesc, + make_tuple(make_pad_transform(invariantLen, 0, srcPad1), + make_pad_transform(toReduceLen, 0, srcPad2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_src2dDesc) = src2dDesc; + } + + if constexpr(dst1d_need_padding) + { + const auto dstPad = GridSize * BlockSize / warpSize - invariantLen; + auto dst1dDesc_2 = + transform_tensor_descriptor(dstDesc, + make_tuple(make_pad_transform(invariantLen, 0, dstPad)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{})); + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dst1dDesc_2; + } + else + { + if(get_thread_local_1d_id() == 0) + *static_cast(p_dst1dDesc) = dstDesc; + } +}; + +struct get_ref_desc_types +{ + static constexpr auto ref_tupleDstLengths = make_tuple(8); + static constexpr auto ref_dstDesc = + make_naive_tensor_descriptor(ref_tupleDstLengths, ref_tupleDstLengths); + + static constexpr index_t ref_invariantLen = ref_dstDesc.GetLength(Number<0>{}); + static constexpr index_t ref_toReduceLen = 8; + + static constexpr auto ref_src2dDesc = + make_naive_tensor_descriptor_packed(make_tuple(ref_invariantLen, ref_toReduceLen)); + + using refType_src2dDesc = decltype(ref_src2dDesc); + using refType_dst1dDesc = decltype(ref_dstDesc); + + // used by the DirectThreadWise and DirectWarpWise method + using refType_src2dDesc_padded_12 = + decltype(transform_tensor_descriptor(ref_src2dDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2), + make_pad_transform(ref_toReduceLen, 0, 2)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}))); + + using refType_dst1dDesc_padded = + decltype(transform_tensor_descriptor(ref_dstDesc, + make_tuple(make_pad_transform(ref_invariantLen, 0, 2)), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0>{}))); +}; + +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename 
get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc_padded_12 = typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; + +template +static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_src2dDesc)); + else + return (*reinterpret_cast(p_src2dDesc)); +}; + +template +static __device__ auto get_reduction_dst1d_descriptor(const void* p_dst1dDesc) +{ + if constexpr(need_padding) + return (*reinterpret_cast(p_dst1dDesc)); + else + return (*reinterpret_cast(p_dst1dDesc)); +}; + +extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, + float alpha, + const void* __restrict__ p_src_global, + float beta, + void* __restrict__ p_dst_global, + const void CONSTANT* ws_global, + long ws_buf2_bytes_offset, + void* __restrict__ indices_global) +{ + (void)p_src_global; + + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); + + const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); + const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); + + using gridwise_2d_reduce = + GridwiseReduction_xy_to_x_direct_warpwise; + + void* const ws_buf2_global = + ws_buf2_bytes_offset > 0 + ? static_cast(static_cast(ws_buf1_global) + ws_buf2_bytes_offset) + : nullptr; + + constexpr int RunId = need_indices ? 3 : 1; + gridwise_2d_reduce::template Run( + src2dDesc, + dst1dDesc, + origReduceLen, + alpha, + static_cast(ws_buf1_global), + beta, + static_cast(p_dst_global), + static_cast(ws_buf2_global), + static_cast(indices_global)); +}; diff --git a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise.cpp b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp similarity index 87% rename from composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise.cpp rename to composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp index 53b0e1e759..aef1545f11 100644 --- a/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise.cpp +++ b/composable_kernel/src/kernel_wrapper/gridwise_generic_reduction_second_call_warpwise_reduce_partial_dims.cpp @@ -42,12 +42,8 @@ using compType = constexpr index_t BlockSize = CK_PARAM_BLOCKSIZE; // tunable -constexpr index_t srcDims = CK_PARAM_IN_DIMS; constexpr index_t dstDims = CK_PARAM_OUT_DIMS; -using toReduceDims = Sequence; -using invariantDims = Sequence; // this could be empty - constexpr ReduceTensorOp_t op = static_cast(CK_PARAM_REDUCE_OP); constexpr NanPropagation_t nanPropaOpt = CK_PARAM_NAN_PROPAGATE == 0 ? 
NanPropagation_t::NOT_PROPAGATE_NAN @@ -59,16 +55,6 @@ constexpr ReduceTensorIndices_t reduceIndicesOpt = CK_PARAM_REDUCE_INDICES == 0 constexpr bool src2d_need_padding = static_cast(CK_PARAM_SRC2D_PADDING); constexpr bool dst1d_need_padding = static_cast(CK_PARAM_DST1D_PADDING); -//////////////////////////////////////////////////////////////////////////////////////// -using specDims = typename sequence_merge::type; - -static_assert(is_valid_sequence_map::value && specDims::Size() == srcDims, - "Wrong invariant and/or toReduce dimensions!"); - -// The number of invariant dimensions can be zero if all dimension are to be reduced -static_assert(invariantDims::Size() > 0 || dstDims == 1, - "If all source dimensions are reduced, the dest should have only one dimension !!"); - constexpr bool indexable = reduce_binary_operator::indexable; constexpr bool need_indices = indexable && (reduceIndicesOpt != ReduceTensorIndices_t::NO_INDICES); @@ -153,12 +139,12 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_pad_transform(toReduceLen, 0, srcPad2)), make_tuple(Sequence<0>{}, Sequence<1>{}), make_tuple(Sequence<0>{}, Sequence<1>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_src2dDesc) = src2dDesc; } @@ -170,17 +156,17 @@ extern "C" __global__ void gridwise_generic_reduce_2_prepare(int GridSize, make_tuple(make_pad_transform(invariantLen, 0, dstPad)), make_tuple(Sequence<0>{}), make_tuple(Sequence<0>{})); - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc_2; } else { - if(hipThreadIdx_x == 0) + if(get_thread_local_1d_id() == 0) *static_cast(p_dst1dDesc) = dst1dDesc; } }; -template +template struct get_ref_desc_types { static constexpr auto ref_tupleDstLengths = @@ -218,16 +204,11 @@ struct get_ref_desc_types make_tuple(Sequence<0>{}))); }; -using refType_src2dDesc = - typename get_ref_desc_types::refType_src2dDesc; -using refType_dst1dDesc = - typename get_ref_desc_types::refType_dst1dDesc; +using refType_src2dDesc = typename get_ref_desc_types::refType_src2dDesc; +using refType_dst1dDesc = typename get_ref_desc_types::refType_dst1dDesc; using refType_src2dDesc_padded_12 = - typename get_ref_desc_types:: - refType_src2dDesc_padded_12; -using refType_dst1dDesc_padded = - typename get_ref_desc_types:: - refType_dst1dDesc_padded; + typename get_ref_desc_types::refType_src2dDesc_padded_12; +using refType_dst1dDesc_padded = typename get_ref_desc_types::refType_dst1dDesc_padded; template static __device__ auto get_reduction_src2d_descriptor(const void* p_src2dDesc) @@ -252,15 +233,15 @@ extern "C" __global__ void gridwise_generic_reduce_2(int origReduceLen, const void* __restrict__ p_src_global, float beta, void* __restrict__ p_dst_global, - void* __restrict__ ws_global, + const void CONSTANT* ws_global, long ws_buf2_bytes_offset, void* __restrict__ indices_global) { (void)p_src_global; - const void* p_src2dDesc = ws_global; - const void* p_dst1dDesc = static_cast(ws_global) + 2048; - void* ws_buf1_global = static_cast(ws_global) + 4096; + const void* p_src2dDesc = cast_pointer_to_generic_address_space(ws_global); + const void* p_dst1dDesc = static_cast(p_src2dDesc) + 2048; + void* ws_buf1_global = const_cast(static_cast(p_src2dDesc) + 4096); const auto src2dDesc = get_reduction_src2d_descriptor(p_src2dDesc); const auto dst1dDesc = get_reduction_dst1d_descriptor(p_dst1dDesc); From 
fd49ff8080b90687108c46f92321ce10ecc743dc Mon Sep 17 00:00:00 2001
From: ltqin
Date: Wed, 20 Oct 2021 07:42:34 +0800
Subject: [PATCH 13/15] add nchw atomic, nhwc and nhwc atomic methods for backward weight (#30)

* add new algorithm from v4r4r2
* program once issue
* add split k function
* redefine code
* add a matrix unmerge
* add b matrix unmerge k0
* trans a and b to gridwise gemm
* nhwc init
* no hacks and vector load
* add hacks
* modify some parameters
* fix tuning parameters for fp32
* fix tuning parameters for fp16
* start change gridwise k split
* init ok
* remove a b matrix k0mk1 desc in grid
* rewrite calculate gridsize
* add kbatch to CalculateBottomIndex
* remove some unused functions
* add clear data function before calling kernel
* out hacks
* in hacks
* rename device convolution file and function name
* modify kBatch value
* fix some tuning code
* start from v4r4 nhwc
* nhwc atomic is able to run
* just for fp32
* enable nchw atomic
* tweak
* tweak
* re-arrange gridwise gemm hot loop for wrw
* add wrw v4r5
* v4r4r5 fp16
* v4r4r4 fp16
* v4r4r2 fp16
* V4R4R4XDLNHWC fp16
* V4R4R2XDLATOMICNCHW fp16
* adjust for fp16
* input gridsize
* change kbatch to gridsize
* testing wrw
* clean up
* k_batch to gridsize
* fix bug
* wrw v4r4r4 kbatch changed to grid size
* wrw v4r4r2 kbatch changed to grid size
* after merge, change gridwise gemm v2r4
* change MakeCBlockClusterAdaptor
* other methods use new gridwise gemm
* clean up
* change pad method to make_right_pad_transform
* move kbatch out of transform function
* clean up and fix bug
* fix bug
* use function types to reduce template parameters
* use auto to replace defined function types
* clean up

Co-authored-by: ltqin
Co-authored-by: Chao Liu
Co-authored-by: Jing Zhang
---
 ...into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp | 147 ++++
 ...into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp | 147 ++++
 ...lution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp | 132 ++++
 ...lution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp | 144 ++++
 .../gridwise_gemm_xdlops_v2r4.hpp             | 666 ++++++++++++++++++
 ...plicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp |   5 +-
 ...mm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp | 258 +++++++
 ...icit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp |  21 +-
 ...mm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp | 290 ++++++++
 ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 276 ++++++++
 ...mm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp | 458 ++++++++++++
 .../include/driver_gemm_xdlops_v2r4.hpp       | 209 ++++++
 .../src/conv_wrw_driver_offline.cpp           | 168 ++++-
 host/host_tensor/include/device.hpp           |   2 +-
 .../include/host_tensor_generator.hpp         |  11 +
 15 files changed, 2914 insertions(+), 20 deletions(-)
 create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp
 create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp
 create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp
 create mode 100644 composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp
 create mode 100644 composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp
 create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp
 create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp
create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp create mode 100644 host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp new file mode 100644 index 0000000000..e533ad9188 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp @@ -0,0 +1,147 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R2_ATOMIC_NCHW_KCYX_NKHW_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// GemmM = K +// GemmK = N * Ho * Wo +// GemmN = C * Y * X +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( + const TensorDescriptor& wei_k_c_y_x_grid_desc, + const TensorDescriptor& in_n_c_hi_wi_grid_desc, + const TensorDescriptor& out_n_k_ho_wo_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_c_hi_wi_grid_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_grid_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_grid_desc.GetLength(I1); + + const auto Hi = in_n_c_hi_wi_grid_desc.GetLength(I2); + const auto Wi = in_n_c_hi_wi_grid_desc.GetLength(I3); + + const auto Ho = out_n_k_ho_wo_grid_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_grid_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_grid_desc.GetLength(I2); + const auto X = wei_k_c_y_x_grid_desc.GetLength(I3); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = C * Y * X; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N, K, Ho * Wo)), + make_tuple(make_pass_through_transform(K), make_merge_transform(make_tuple(N, Ho * Wo))), + make_tuple(Sequence<1>{}, Sequence<0, 2>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, 
Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_c_hip_wip_grid_desc = transform_tensor_descriptor( + in_n_c_hi_wi_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_c_y_ho_x_wo_grid_desc = transform_tensor_descriptor( + in_n_c_hip_wip_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pass_through_transform(C), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2, 3>{}, Sequence<4, 5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_c_y_ho_x_wo_grid_desc, + make_tuple(make_merge_transform(make_tuple(C, Y, X)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 2, 4>{}, Sequence<0, 3, 5>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, C * Y * X)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(C * Y * X)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..213e1d6135 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,147 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_ATOMIC_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + 
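+// Note on the split-K ("atomic") layout used by the backward-weight transforms in
+// this patch: GemmKTotal = N * Ho * Wo is right-padded up to GemmKPad so that it
+// unmerges exactly into (GemmKBatch, GemmK0, GemmK1), with
+// GemmK0 = GemmKPad / (GemmKBatch * GemmK1). Each of the GemmKBatch grid groups
+// then reduces its own K-slice, and the partial C tiles are combined in global
+// memory by atomic accumulation (hence the output buffer is cleared beforehand).
+// A minimal sketch of one valid way for a host driver to choose GemmKPad; this
+// helper is a hypothetical illustration, not a function used by these transforms:
+//
+//   constexpr index_t CalculateGemmKPad(index_t KTotal, index_t KBatch,
+//                                       index_t K1, index_t KPerBlock)
+//   {
+//       // smallest multiple of (KBatch * K1 * KPerBlock) covering KTotal, so
+//       // that GemmK0 = GemmKPad / (KBatch * K1) is a whole multiple of
+//       // KPerBlock, as the gridwise GEMM's validity check requires
+//       const index_t chunk = KBatch * K1 * KPerBlock;
+//       return ((KTotal + chunk - 1) / chunk) * chunk;
+//   }
+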
+// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, 
GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: output tensor + const auto out_gemmktotal_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..f1e1826d16 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,132 @@ +#ifndef CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R4_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: in +// B: wei +// C: out +// GemmM = N * Ho * Wo +// GemmN = K +// GemmK = Y * X * C +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + 
const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmK = N * Ho * Wo; + const auto GemmK0 = GemmK / GemmK1; + + // A: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmk_gemmm_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = + transform_tensor_descriptor(in_gemmk_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // B: output tensor + const auto out_gemmk_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)), + make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmk0_gemmn_gemmk1_grid_desc = + transform_tensor_descriptor(out_gemmk_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = transform_tensor_descriptor( + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)), + make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..02e61c0ea3 --- /dev/null +++ b/composable_kernel/include/problem_transform/transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,144 @@ +#ifndef 
CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP +#define CK_TRANSFORM_BACKWARD_WEIGHT_CONVOLUTION_INTO_GEMM_V4R4R5_NHWC_KYXC_NHWK_HPP + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" + +namespace ck { + +// A: out +// B: in +// C: wei +// GemmM = K +// GemmN = Y * X * C +// GemmKTotal = N * Ho * Wo +template +__host__ __device__ constexpr auto +transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( + const TensorDescriptor& in_n_hi_wi_c_grid_desc, + const TensorDescriptor& wei_k_y_x_c_grid_desc, + const TensorDescriptor& out_n_ho_wo_k_grid_desc, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + Number, + GemmKBatchType GemmKBatch, + GemmKPadType GemmKPad) +{ + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + constexpr auto GemmK1 = Number{}; + + const auto N = in_n_hi_wi_c_grid_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_grid_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_grid_desc.GetLength(I3); + + const auto Hi = in_n_hi_wi_c_grid_desc.GetLength(I1); + const auto Wi = in_n_hi_wi_c_grid_desc.GetLength(I2); + + const auto Ho = out_n_ho_wo_k_grid_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_grid_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_grid_desc.GetLength(I1); + const auto X = wei_k_y_x_c_grid_desc.GetLength(I2); + + const auto ConvStrideH = conv_strides[I0]; + const auto ConvStrideW = conv_strides[I1]; + + const auto ConvDilationH = conv_dilations[I0]; + const auto ConvDilationW = conv_dilations[I1]; + + const auto InLeftPadH = in_left_pads[I0]; + const auto InLeftPadW = in_left_pads[I1]; + + const auto InRightPadH = in_right_pads[I0]; + const auto InRightPadW = in_right_pads[I1]; + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + const index_t GemmK0 = GemmKPad / (GemmKBatch * GemmK1); + + // A: output tensor + const auto out_gemmktotal_gemmm_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K)); + + const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor( + out_gemmktotal_gemmm_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor( + out_gemmkpad_gemmm_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmM)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // B: input tensor + const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor( + in_n_hi_wi_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_pad_transform(Hi, InLeftPadH, InRightPadH), + make_pad_transform(Wi, InLeftPadW, InRightPadW), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{})); + + const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor( + in_n_hip_wip_c_grid_desc, + make_tuple(make_pass_through_transform(N), + make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)), + 
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)), + make_pass_through_transform(C)), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{})); + + const auto in_gemmktotal_gemmn_grid_desc = + transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc, + make_tuple(make_merge_transform(make_tuple(Y, X, C)), + make_merge_transform(make_tuple(N, Ho, Wo))), + make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}), + make_tuple(Sequence<1>{}, Sequence<0>{})); + + const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor( + in_gemmktotal_gemmn_grid_desc, + make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0>{}, Sequence<1>{})); + + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor( + in_gemmkpad_gemmn_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(GemmKBatch, GemmK0, GemmK1)), + make_pass_through_transform(GemmN)), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); + + // C: weight tensor + const auto wei_gemmm_gemmn_grid_desc = + make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C)); + + return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc); +} + +} // namespace ck +#endif diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp new file mode 100644 index 0000000000..8a9c932f4c --- /dev/null +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r4.hpp @@ -0,0 +1,666 @@ +#ifndef CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP +#define CK_GRIDWISE_GEMM_XDLOPS_V2R4_HPP + +#include "common_header.hpp" +#include "multi_index_transform_helper.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "blockwise_gemm_xdlops.hpp" +#include "blockwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_transfer.hpp" +#include "threadwise_tensor_slice_set.hpp" + +namespace ck { + +#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const ABK0MK1GridDesc a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const CBlockClusterAdaptor c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); +} +#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS + __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) +#endif + kernel_gemm_xdlops_v2r4(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + const void CONSTANT* 
p_a_b_k0_m_k1_grid_desc, + const void CONSTANT* p_b_b_k0_n_k1_grid_desc, + const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const void CONSTANT* p_c_block_cluster_adaptor) +{ + constexpr index_t shared_block_size = + GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB); + + const auto a_b_k0_m_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_a_b_k0_m_k1_grid_desc)); + const auto b_b_k0_n_k1_grid_desc = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_b_b_k0_n_k1_grid_desc)); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); + const auto c_block_cluster_adaptor = *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); + + __shared__ FloatAB p_shared_block[shared_block_size]; + + GridwiseGemm::Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_b_k0_m_k1_grid_desc, + b_b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); +} +#endif + +template +struct GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 +{ + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; + + // K1 should be Number<...> + static constexpr auto K1 = Number{}; + + __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte() + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + constexpr auto b_block_space_size = + math::integer_least_multiple(b_k0_n_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + return (a_block_space_size + b_block_space_size) * sizeof(FloatAB); + } + + // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} + __host__ __device__ static constexpr bool + CheckValidity(const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + index_t M01, + index_t N01) + { + static_assert(is_known_at_compile_time>::value, + "wrong! 
K1 need to be known at compile-time"); + + static_assert((MPerBlock % (MPerXDL * MRepeat) == 0) && + (NPerBlock % (NRepeat * NPerXDL)) == 0, + "Invalid tuning param!"); + + const auto M = a_b_k0_m_k1_grid_desc.GetLength(I2); + const auto N = b_b_k0_n_k1_grid_desc.GetLength(I2); + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0); + + if(!(M == c_m_n_grid_desc.GetLength(I0) && N == c_m_n_grid_desc.GetLength(I1) && + K0 == b_b_k0_n_k1_grid_desc.GetLength(I1) && + K1 == a_b_k0_m_k1_grid_desc.GetLength(I3) && + K1 == b_b_k0_n_k1_grid_desc.GetLength(I3) && + KBatch == b_b_k0_n_k1_grid_desc.GetLength(I0))) + return false; + + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0)) + return false; + + // check M01, N01 + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + if(!(M0 % M01 == 0 && N0 % N01 == 0)) + return false; + + // TODO: also check validity of all components (blockwise-copy, threadwise-copy, etc) + return true; + } + + __host__ __device__ static constexpr index_t + CalculateGridSize(const CMNGridDesc& c_m_n_grid_desc, index_t KBatch) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + const index_t grid_size = (M / MPerBlock) * (N / NPerBlock) * KBatch; + + return grid_size; + } + + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; + + return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); + } + + // return block_id to C matrix tile idx (m0, n0) mapping + __host__ __device__ static constexpr auto MakeCBlockClusterAdaptor( + const CMNGridDesc& c_m_n_grid_desc, index_t M01, index_t N01, index_t KBatch) + { + const auto M = c_m_n_grid_desc.GetLength(I0); + const auto N = c_m_n_grid_desc.GetLength(I1); + + constexpr auto M1 = Number{}; + constexpr auto N1 = Number{}; + + const auto M0 = M / M1; + const auto N0 = N / N1; + + const auto M00 = M0 / M01; + const auto N00 = N0 / N01; + + const auto kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_pass_through_transform(KBatch), + make_unmerge_transform(make_tuple(M00, M01)), + make_unmerge_transform(make_tuple(N00, N01))), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 3>{}, Sequence<2, 4>{})); + + const auto c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(KBatch, M00, N00, M01, N01))), + 
make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto c_blockid_to_kbatch_m0_n0_block_cluster_adaptor = + chain_tensor_adaptors(kbatch_m00_m01_n00_n01_to_m0_n0_block_cluster_adaptor, + c_blockid_to_kbatch_m00_m01_n00_n01_block_cluster_adaptor); + + return c_blockid_to_kbatch_m0_n0_block_cluster_adaptor; + } + + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1, 1)); + + __device__ static void Run(const FloatAB* __restrict__ p_a_grid, + const FloatAB* __restrict__ p_b_grid, + FloatC* __restrict__ p_c_grid, + FloatAB* __restrict__ p_shared_block, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + const CBlockClusterAdaptor& c_block_cluster_adaptor) + { + const auto a_grid_buf = make_dynamic_buffer( + p_a_grid, a_b_k0_m_k1_grid_desc.GetElementSpaceSize()); + const auto b_grid_buf = make_dynamic_buffer( + p_b_grid, b_b_k0_n_k1_grid_desc.GetElementSpaceSize()); + auto c_grid_buf = make_dynamic_buffer( + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); + + const auto K0 = a_b_k0_m_k1_grid_desc.GetLength(I1); + + // divide block work by [M, N] + const auto block_work_idx = + c_block_cluster_adaptor.CalculateBottomIndex(make_multi_index(get_block_1d_id())); + + const index_t k_batch_id = block_work_idx[I0]; + // HACK: this force m/n_block_data_idx_on_grid into SGPR + const index_t m_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock); + + const index_t n_block_data_idx_on_grid = + __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock); + + // lds max alignment + constexpr auto max_lds_align = K1; + + // A matrix in LDS memory, dst of blockwise copy + constexpr auto a_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto a_b_k0_m_k1_block_desc = [&]() { + if constexpr(ABlockLdsExtraM) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // B matrix in LDS memory, dst of blockwise copy + constexpr auto b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number{}, Number{}, K1), + make_tuple(Number{} * K1, K1, I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); + } + }(); + + constexpr auto b_b_k0_n_k1_block_desc = [&]() { + if constexpr(BBlockLdsExtraN) + { + return make_naive_tensor_descriptor( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + make_tuple(Number{} * Number{} * K1, + Number{} * K1, + K1, + I1)); + } + else + { + return make_naive_tensor_descriptor_aligned( + make_tuple(Number<1>{}, Number{}, Number{}, K1), + max_lds_align); + } + }(); + // A matrix blockwise copy + auto a_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + ABlockTransferThreadSliceLengths_K0_M_K1, + 
ABlockTransferThreadClusterLengths_K0_M_K1, + ABlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(a_b_k0_m_k1_grid_desc), + decltype(a_b_k0_m_k1_block_desc), + ABlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + ABlockTransferSrcVectorDim, + 3, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_K1, + 1, + 1, + AThreadTransferSrcResetCoordinateAfterRun, + true>( + a_b_k0_m_k1_grid_desc, + make_multi_index(k_batch_id, 0, m_block_data_idx_on_grid, 0), + a_b_k0_m_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // B matrix blockwise copy + auto b_blockwise_copy = + BlockwiseTensorSliceTransfer_v4, + BBlockTransferThreadSliceLengths_K0_N_K1, + BBlockTransferThreadClusterLengths_K0_N_K1, + BBlockTransferThreadClusterArrangeOrder, + FloatAB, + FloatAB, + decltype(b_b_k0_n_k1_grid_desc), + decltype(b_b_k0_n_k1_block_desc), + BBlockTransferSrcAccessOrder, + Sequence<0, 2, 1, 3>, + BBlockTransferSrcVectorDim, + 3, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_K1, + 1, + 1, + BThreadTransferSrcResetCoordinateAfterRun, + true>( + b_b_k0_n_k1_grid_desc, + make_multi_index(k_batch_id, 0, n_block_data_idx_on_grid, 0), + b_b_k0_n_k1_block_desc, + make_multi_index(0, 0, 0, 0)); + + // GEMM definition + // c_mtx += transpose(a_mtx) * b_mtx + // a_mtx[KPerBlock, MPerBlock] is in LDS + // b_mtx[KPerBlock, NPerBlock] is in LDS + // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in + // register + // sanity check + + auto blockwise_gemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; + + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); + + // LDS allocation for A and B: be careful of alignment + constexpr auto a_block_space_size = + math::integer_least_multiple(a_k0_m_k1_block_desc.GetElementSpaceSize(), max_lds_align); + + FloatAB* p_a_block = p_shared_block; + FloatAB* p_b_block = p_shared_block + a_block_space_size; + + constexpr auto a_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(0, KPerBlock, 0, 0); + + // hack to control index calculation when iterating over A and B matrix for threadwise copy + constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; + constexpr auto b_k0_n_k1_grid_step_hacks = BGridStepHacks{}; + + // hack to control index calculation when move slice window for A and B matrix for + // threadwise copy + constexpr auto a_k0_m_k1_grid_move_slice_window_step_hack = AGridMoveSliceWindowStepHacks{}; + constexpr auto b_k0_n_k1_grid_move_slice_window_step_hack = BGridMoveSliceWindowStepHacks{}; + + auto a_block_buf = make_dynamic_buffer( + p_a_block, a_k0_m_k1_block_desc.GetElementSpaceSize()); + auto b_block_buf = make_dynamic_buffer( + p_b_block, b_k0_n_k1_block_desc.GetElementSpaceSize()); + + // preload data into LDS + { + a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + } + + // main body + index_t k_block_data_begin = 0; + + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_b_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_step_hack); + + 
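// NOTE: RunRead below pulls the next K-slice from global memory into registers while + // the tile already staged in LDS is consumed by blockwise_gemm.Run; the two + // block_sync_lds() calls fence LDS between the GEMM reads and the RunWrite stores of + // consecutive iterations. +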
a_blockwise_copy.RunRead(a_b_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + + block_sync_lds(); + + b_blockwise_copy.RunRead(b_b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_b_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_b_k0_n_k1_block_desc, b_block_buf); + + k_block_data_begin += KPerBlock; + } while(k_block_data_begin < (K0 - KPerBlock)); + + // tail + { + block_sync_lds(); + + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + } + + // output: register to global memory + { + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); + + // calculate origin of thread output tensor on global memory + // blockwise GEMM c matrix starting index + const auto c_thread_mtx_on_block = + blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); + + const index_t m_thread_data_on_grid = + m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; + + const index_t n_thread_data_on_grid = + n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; + + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + + auto c_thread_copy = + ThreadwiseTensorSliceTransfer_v1r3, + CThreadTransferSrcDstAccessOrder, + CThreadTransferSrcDstVectorDim, + CThreadTransferDstScalarPerVector, + CGlobalMemoryDataOperation, + 1, + true>{ + + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2])}; + + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf, + 
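// NOTE: the step-hack tuple passed as the last argument appears to hint, per tensor + // transform, which coordinate updates can take a cheaper path as the copy steps + // through C; all-zero sequences (as for C here) simply select the generic index + // calculation. +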
c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); + } + } +}; // namespace ck + +} // namespace ck +#endif diff --git a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp index b5ff1db296..8258aa0e66 100644 --- a/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_backward_data_implicit_gemm_v4r1_xdlops_nhwc_kyxc_nhwk.hpp @@ -3,6 +3,7 @@ #include "host_tensor.hpp" #include "transform_backward_data_convolution_into_gemm_v4r1_nhwc_kyxc_nhwk.hpp" #include "driver_gemm_xdlops_v2r3.hpp" +#include "debug.hpp" template +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw( + const InLengths& in_n_c_hi_wi_lengths, + const WeiLengths& wei_k_c_y_x_lengths, + const OutLengths& out_n_k_ho_wo_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, + const Tensor& out_n_k_ho_wo, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); + + in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + out_n_k_ho_wo_device_buf.ToDevice(out_n_k_ho_wo.mData.data()); + + const auto in_n_c_hi_wi_desc = make_naive_tensor_descriptor_packed(in_n_c_hi_wi_lengths); + const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); + const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); + +#if 1 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerWave = 32; + constexpr index_t GemmNPerWave = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 2, 8>; + using GemmABlockTransferThreadClusterLengths_GemmB_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 64, 1>; + // using vector load 4, so config's wo*ho must be a multiple of 4 + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 2, 8>; + using GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 1; + constexpr index_t 
GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_c_hi_wi_desc.GetLength(I0); + const auto C = in_n_c_hi_wi_desc.GetLength(I1); + const auto K = out_n_k_ho_wo_desc.GetLength(I1); + + const auto Ho = out_n_k_ho_wo_desc.GetLength(I2); + const auto Wo = out_n_k_ho_wo_desc.GetLength(I3); + + const auto Y = wei_k_c_y_x_desc.GetLength(I2); + const auto X = wei_k_c_y_x_desc.GetLength(I3); + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " GemmKPad: " << GemmKPad + << std::endl; + const auto descs = + transform_backward_weight_convolution_into_gemm_v4r4r2_atomic_nchw_kcyx_nkhw_pad( + wei_k_c_y_x_desc, + in_n_c_hi_wi_desc, + out_n_k_ho_wo_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto out_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB + Sequence<0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2+: GemmM + Sequence<0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB + Sequence<0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0>{}, // 2-: GemmM + Sequence<0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmB + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmB + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 +
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 1, 0, 0, 0, 0>{}; + + constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = + driver_gemm_xdlops_v2r4, + Sequence<0, 2, 1, 3>, + 3, + GemmABlockTransferSrcScalarPerVector_GemmK1, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmB_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmB_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1, 3>, + Sequence<0, 2, 1, 3>, + 3, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<3, 0, 1, 2, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, + true, + true>; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + float perf = static_cast(calculate_convolution_flops( + in_n_c_hi_wi_desc, wei_k_c_y_x_desc, out_n_k_ho_wo_desc)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl; + } + + wei_k_c_y_x_device_buf.ToDevice(wei_k_c_y_x.mData.data()); + driver_gemm_xdlops(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + out_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_c_y_x_device_buf.FromDevice(wei_k_c_y_x.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp 
index b8ecfb4be9..ac75c56bf5 100644 --- a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -4,7 +4,8 @@ #include "transform_backward_weight_convolution_into_gemm_v4r4r2_nchw_kcyx_nkhw.hpp" #include "driver_gemm_xdlops_v2r3.hpp" -template & in_n_c_hi_wi, - Tensor& wei_k_c_y_x, + const Tensor& in_n_c_hi_wi, + Tensor& wei_k_c_y_x, const Tensor& out_n_k_ho_wo, ck::index_t nrepeat) { @@ -35,8 +36,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk constexpr auto I1 = Number<1>{}; constexpr auto I2 = Number<2>{}; - DeviceMem in_n_c_hi_wi_device_buf(sizeof(TInWei) * in_n_c_hi_wi.mDesc.GetElementSpace()); - DeviceMem wei_k_c_y_x_device_buf(sizeof(TInWei) * wei_k_c_y_x.mDesc.GetElementSpace()); + DeviceMem in_n_c_hi_wi_device_buf(sizeof(TIn) * in_n_c_hi_wi.mDesc.GetElementSpace()); + DeviceMem wei_k_c_y_x_device_buf(sizeof(TWei) * wei_k_c_y_x.mDesc.GetElementSpace()); DeviceMem out_n_k_ho_wo_device_buf(sizeof(TOut) * out_n_k_ho_wo.mDesc.GetElementSpace()); in_n_c_hi_wi_device_buf.ToDevice(in_n_c_hi_wi.mData.data()); @@ -47,7 +48,7 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk const auto wei_k_c_y_x_desc = make_naive_tensor_descriptor_packed(wei_k_c_y_x_lengths); const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); -#if 1 +#if 0 // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; @@ -164,9 +165,9 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk { float ave_time = driver_gemm_xdlops_v2r3< BlockSize, - TInWei, + TIn, TAcc, - TOut, + TWei, InMemoryDataOperationEnum_t::Set, decltype(out_gemmk0_gemmm_gemmk1_grid_desc), decltype(in_gemmk0_gemmn_gemmk1_grid_desc), @@ -207,8 +208,8 @@ void device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nk true, // ABlockLdsExtraM true // BBlockLdsExtraN >(static_cast(out_n_k_ho_wo_device_buf.GetDeviceBuffer()), - static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), - static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), + static_cast(in_n_c_hi_wi_device_buf.GetDeviceBuffer()), + static_cast(wei_k_c_y_x_device_buf.GetDeviceBuffer()), out_gemmk0_gemmm_gemmk1_grid_desc, in_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmm_gemmn_grid_desc, diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..579c7a1200 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,290 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& 
out_n_ho_wo_k, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [128, 256, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_hi_wi_c_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_desc.GetLength(I1); + const 
auto X = wei_k_y_x_c_desc.GetLength(I2); + + const auto GemmM = Y * X * C; + const auto GemmN = K; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " GemmKPad: " << GemmKPad + << std::endl; + + const auto descs = + transform_backward_weight_convolution_into_gemm_v4r4r4_atomic_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmKBatch + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmKBatch + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0,
0>{}; + + constexpr auto out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::AtomicAdd, + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 6, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + driver_gemm_xdlops(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + 
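// (the FromDevice below reads back the result of the single nrepeat = 0 launch above, + // not of the timed loop: with AtomicAdd each timed repeat keeps accumulating into the + // weight buffer, which is why the buffer is re-initialized once before that final run) +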
wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..bc5d599604 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,276 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r3.hpp" +#include "debug.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 2; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 2; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 4] for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + 
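// NOTE: with MRepeat = NRepeat = 2 below, each wave covers a (2 * 32) x (2 * 32) = + // 64 x 64 sub-tile, so the 128 x 128 block tile needs MWaves = NWaves = 2; at 64 lanes + // per wave that is 2 * 2 * 64 = 256 threads, matching the BlockSize above. +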
constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; + +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 4; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 4>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r4_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}); + + const auto in_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto out_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto in_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, // 1+: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, // 1-: GemmM + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{})); // 2-: GemmK1 + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmN + Sequence<0, 0, 0, 0, 0>{}), // 2+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmN + Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 
0, 0, 0, 0, 0>{}, // 4+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0>{}; + + constexpr auto out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + for(index_t i = 0; i < 5; ++i) + { + float ave_time = driver_gemm_xdlops_v2r3< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::Set, + decltype(in_gemmk0_gemmm_gemmk1_grid_desc), + decltype(out_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 2, 1>, + Sequence<0, 2, 1>, + 1, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(in_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>(static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + in_gemmk0_gemmm_gemmk1_grid_desc, + out_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + in_gemmk0_gemmm_gemmk1_grid_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + out_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + const auto N = out_n_ho_wo_k_lengths[I0]; + const auto K = out_n_ho_wo_k_lengths[I3]; + const auto C = wei_k_y_x_c_lengths[I3]; + + const auto Ho = out_n_ho_wo_k_lengths[I1]; + const auto Wo = out_n_ho_wo_k_lengths[I2]; + + const auto Y = wei_k_y_x_c_lengths[I1]; + const auto X = wei_k_y_x_c_lengths[I2]; + + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + 
<< std::endl; + } + } + + // copy result back to host + wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp new file mode 100644 index 0000000000..29b404f7d0 --- /dev/null +++ b/host/driver_offline/include/device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp @@ -0,0 +1,458 @@ +#include +#include "device.hpp" +#include "host_tensor.hpp" +#include "transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk.hpp" +#include "driver_gemm_xdlops_v2r4.hpp" + +template +void device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk( + const InLengths& in_n_hi_wi_c_lengths, + const WeiLengths& wei_k_y_x_c_lengths, + const OutLengths& out_n_ho_wo_k_lengths, + const ConvStrides& conv_strides, + const ConvDilations& conv_dilations, + const InLeftPads& in_left_pads, + const InRightPads& in_right_pads, + const Tensor& in_n_hi_wi_c, + Tensor& wei_k_y_x_c, + const Tensor& out_n_ho_wo_k, + GridSizeType desired_grid_size, + ck::index_t nrepeat) +{ + using namespace ck; + + std::cout << __func__ << std::endl; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + DeviceMem in_n_hi_wi_c_device_buf(sizeof(TIn) * in_n_hi_wi_c.mDesc.GetElementSpace()); + DeviceMem wei_k_y_x_c_device_buf(sizeof(TWei) * wei_k_y_x_c.mDesc.GetElementSpace()); + DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); + + in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); + + const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); + const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); + const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); + +#if 0 + // [M, N, K0, K1] = [256, 128, 4, 4], C 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 256, 4, 4], C 128, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 256; + constexpr index_t GemmKPerBlock 
= 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 4; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 128, 4, 4], C 64, for fp32 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 4; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 4; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 4, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 32, 2>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 4; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [256, 128, 4, 8], C 128, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 4; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [128, 128, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 256; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using 
GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 16, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 0 + // [M, N, K0, K1] = [128, 64, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 128; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 128, 4, 8], C 64, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 128; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 2; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 16, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#elif 1 + // [M, N, K0, K1] = [64, 64, 4, 8], C 32, for fp16 + constexpr index_t BlockSize = 128; + + constexpr index_t GemmMPerBlock = 64; + constexpr index_t GemmNPerBlock = 64; + constexpr index_t GemmKPerBlock = 4; + + constexpr index_t GemmMPerXDL = 32; + constexpr index_t GemmNPerXDL = 32; + constexpr index_t GemmK1 = 8; + + constexpr index_t MRepeat = 2; + constexpr index_t NRepeat = 1; + + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t 
GemmABlockTransferSrcScalarPerVector_GemmM = 8; + constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 2; + + using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 1, 8, 2>; + using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 4, 8, 4>; + + constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmN = 8; + constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 2; + + constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; +#endif + + const auto N = in_n_hi_wi_c_desc.GetLength(I0); + const auto C = in_n_hi_wi_c_desc.GetLength(I3); + const auto K = out_n_ho_wo_k_desc.GetLength(I3); + + const auto Ho = out_n_ho_wo_k_desc.GetLength(I1); + const auto Wo = out_n_ho_wo_k_desc.GetLength(I2); + + const auto Y = wei_k_y_x_c_desc.GetLength(I1); + const auto X = wei_k_y_x_c_desc.GetLength(I2); + + const auto GemmM = K; + const auto GemmN = Y * X * C; + const auto GemmKTotal = N * Ho * Wo; + + const auto GemmK = GemmKTotal / GemmK1; + + const auto GridMN = GemmM * GemmN / (GemmMPerBlock * GemmNPerBlock); + const index_t GemmKBatch = std::max(desired_grid_size / GridMN, 1); + const index_t BatchLen = std::ceil(GemmK * 1.0 / (GemmKPerBlock * GemmKBatch)); + const index_t GemmK0 = BatchLen * GemmKPerBlock; + const index_t GemmKPad = GemmKBatch * GemmK0 * GemmK1; + + std::cout << "GemmKTotal: " << GemmKTotal << " GridSizeMN: " << GridMN + << " GemmKBatch: " << GemmKBatch << " GemmK0: " << GemmK0 << " GemmKPad: " << GemmKPad + << std::endl; + + const auto descs = transform_backward_weight_convolution_into_gemm_v4r4r5_nhwc_kyxc_nhwk_pad( + in_n_hi_wi_c_desc, + wei_k_y_x_c_desc, + out_n_ho_wo_k_desc, + conv_strides, + conv_dilations, + in_left_pads, + in_right_pads, + Number{}, + GemmKBatch, + GemmKPad); + + const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; + const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; + const auto wei_gemmm_gemmn_grid_desc = descs[I2]; + + // HACK: hacks that control index calculation when iterating over A, B, C matrix + constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: GemmKBatch + Sequence<0, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 2+: GemmM + Sequence<0, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: GemmKBatch + Sequence<0, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0>{}, // 2-: GemmM + Sequence<0, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks = make_tuple( + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 0+: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}, // 1+: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 2+: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0>{}), // 3+: GemmK1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 0-: GemmKBatch + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{}, // 1-: GemmK0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 2-: GemmN + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0>{})); // 3-: GemmK1 + + constexpr auto wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks = + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2 +
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2 + + constexpr auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0>{}; + + constexpr auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0>{}; + + const auto driver_gemm_xdlops = driver_gemm_xdlops_v2r4< + BlockSize, + TIn, + TAcc, + TWei, + InMemoryDataOperationEnum_t::AtomicAdd, + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc), + decltype(wei_gemmm_gemmn_grid_desc), + GemmMPerBlock, + GemmNPerBlock, + GemmKPerBlock, + GemmMPerXDL, + GemmNPerXDL, + GemmK1, + MRepeat, + NRepeat, + GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, + GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 2, 3>, + 2, + GemmABlockTransferSrcScalarPerVector_GemmM, + GemmABlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, + GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, + Sequence<0, 1, 2, 3>, + Sequence<0, 1, 3, 2>, + 2, + GemmBBlockTransferSrcScalarPerVector_GemmN, + GemmBBlockTransferDstScalarPerVector_GemmK1, + false, // don't move back src coordinate after threadwise copy + Sequence<2, 3, 0, 1, 7, 5, 4, 6>, + 7, + GemmCThreadTransferDstScalarPerVector, + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks), + decltype(wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), + decltype(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), + decltype(in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks), + false, // CAccessOrderMRepeatNRepeat + true, + true>; + + // timing + for(index_t i = 0; i < 5; ++i) + { + float ave_time = + driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + nrepeat); + + { + float perf = static_cast((std::size_t(2) * N * K * Ho * Wo * C * Y * X)) / + (std::size_t(1000) * 1000 * 1000) / ave_time; + + std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" + << std::endl; + } + } + + // verification + wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); + 
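// NOTE: the kernel accumulates into C with AtomicAdd, so the weight buffer must hold a + // known starting value before this verification launch; it is re-uploaded from the host + // tensor just above, and only the result of this single run is checked. +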
driver_gemm_xdlops(static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), + static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), + static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc, + wei_gemmm_gemmn_grid_desc, + debug::debug_driver_gemm_xdlops_v2r3::M01, + debug::debug_driver_gemm_xdlops_v2r3::N01, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_step_hacks, + wei_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks, + out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, + in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, + 0); + // copy result back to host + wei_k_y_x_c_device_buf.FromDevice(wei_k_y_x_c.mData.data()); +} diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp new file mode 100644 index 0000000000..65c4f62367 --- /dev/null +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r4.hpp @@ -0,0 +1,209 @@ +#ifndef DRIVER_GEMM_XDLOPS_V2R4 +#define DRIVER_GEMM_XDLOPS_V2R4 + +#include "common_header.hpp" +#include "tensor_descriptor.hpp" +#include "tensor_descriptor_helper.hpp" +#include "gridwise_gemm_xdlops_v2r4.hpp" + +template +__host__ float driver_gemm_xdlops_v2r4(const FloatAB* p_a_grid, + const FloatAB* p_b_grid, + FloatC* p_c_grid, + const ABK0MK1GridDesc& a_b_k0_m_k1_grid_desc, + const BBK0NK1GridDesc& b_b_k0_n_k1_grid_desc, + const CMNGridDesc& c_m_n_grid_desc, + ck::index_t M01, + ck::index_t N01, + AGridStepHacks, + BGridStepHacks, + CGridStepHacks, + AGridMoveSliceWindowStepHacks, + BGridMoveSliceWindowStepHacks, + ck::index_t nrepeat) + +{ + using namespace ck; + + constexpr auto I0 = Number<0>{}; + constexpr auto I1 = Number<1>{}; + constexpr auto I2 = Number<2>{}; + constexpr auto I3 = Number<3>{}; + + using GridwiseGemm = + GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4; + + { + std::cout << "a_b_k0_m_k1_grid_desc{" << a_b_k0_m_k1_grid_desc.GetLength(I0) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I1) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I2) << ", " + << a_b_k0_m_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "b_b_k0_n_k1_grid_desc{" << b_b_k0_n_k1_grid_desc.GetLength(I0) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I1) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I2) << ", " + << b_b_k0_n_k1_grid_desc.GetLength(I3) << "}" << std::endl; + + std::cout << "c_m_n_grid_desc{ " << c_m_n_grid_desc.GetLength(I0) << ", " + << c_m_n_grid_desc.GetLength(I1) << "}" << std::endl; + } + + if(!GridwiseGemm::CheckValidity( + a_b_k0_m_k1_grid_desc, b_b_k0_n_k1_grid_desc, c_m_n_grid_desc, M01, N01)) + { + throw std::runtime_error( + "wrong! 
GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4 has invalid setting");
+    }
+
+    const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc =
+        GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc);
+
+    using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
+
+    const auto KBatch = a_b_k0_m_k1_grid_desc.GetLength(I0);
+
+    const auto c_block_cluster_adaptor =
+        GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc, M01, N01, KBatch);
+
+    using CBlockClusterAdaptor = decltype(c_block_cluster_adaptor);
+
+    const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc, KBatch);
+    {
+        std::cout << "gridSize : " << grid_size << std::endl;
+    }
+
+    const auto kernel = kernel_gemm_xdlops_v2r4,
+                                                remove_reference_t,
+                                                remove_reference_t,
+                                                remove_reference_t>;
+
+#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
+    float ave_time = launch_and_time_kernel(kernel,
+                                            nrepeat,
+                                            dim3(grid_size),
+                                            dim3(BlockSize),
+                                            0,
+                                            p_a_grid,
+                                            p_b_grid,
+                                            p_c_grid,
+                                            a_b_k0_m_k1_grid_desc,
+                                            b_b_k0_n_k1_grid_desc,
+                                            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                            c_block_cluster_adaptor);
+
+#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
+    DeviceMem a_b_k0_m_k1_grid_desc_dev_buf(sizeof(ABK0MK1GridDesc));
+    DeviceMem b_b_k0_n_k1_grid_desc_dev_buf(sizeof(BBK0NK1GridDesc));
+    DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc));
+    DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor));
+
+    a_b_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_b_k0_m_k1_grid_desc);
+    b_b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_b_k0_n_k1_grid_desc);
+    c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
+    c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);
+
+    float ave_time = launch_and_time_kernel(
+        kernel,
+        nrepeat,
+        dim3(grid_size),
+        dim3(BlockSize),
+        0,
+        p_a_grid,
+        p_b_grid,
+        p_c_grid,
+        cast_pointer_to_constant_address_space(a_b_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+        cast_pointer_to_constant_address_space(b_b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+        cast_pointer_to_constant_address_space(
+            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+        cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+#endif
+    return ave_time;
+}
+#endif
diff --git a/host/driver_offline/src/conv_wrw_driver_offline.cpp b/host/driver_offline/src/conv_wrw_driver_offline.cpp
index 310dbfe1eb..50f4d6a9b3 100644
--- a/host/driver_offline/src/conv_wrw_driver_offline.cpp
+++ b/host/driver_offline/src/conv_wrw_driver_offline.cpp
@@ -14,13 +14,25 @@
 #include "host_conv_bwd_weight.hpp"
 #include "device_tensor.hpp"
 #include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk.hpp"
+#include "device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk.hpp"

 #define USE_DYNAMIC_MODE 1
-#define USE_CONV_WRW_V4R4R2_XDL_NCHW 1
+#define USE_CONV_WRW_V4R4R2_XDL_NCHW 0
+#define USE_CONV_WRW_V4R4R4_XDL_NHWC 0
+#define USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW 0
+#define USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC 0
+#define USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC 1

 enum ConvBackwardWeightAlgo
 {
-    V4R4R2XDLNCHW,
+    V4R4R2XDLNCHW, // 0
+    V4R4R4XDLNHWC,
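+    /* [editor note] the enum values mirror the USE_CONV_WRW_* switches above, of which
+       exactly one is enabled at a time; the numeric comments presumably correspond to the
+       algo command-line argument parsed in main(). */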
// 1 + V4R4R2XDLATOMICNCHW, // 2 + V4R4R4XDLATOMICNHWC, // 3 + V4R4R5XDLATOMICNHWC, // 4 }; int main(int argc, char* argv[]) @@ -37,10 +49,11 @@ int main(int argc, char* argv[]) #if USE_DYNAMIC_MODE // dynamic mode - if(argc != 22) + if(argc != 23) { printf("arg1 to 6: layout, algo, do_verification, init_method, do_log, nrepeat\n"); printf("rest: N, K, C, Y, X, Hi, Wi, Sy, Sx, Dy, Dx, LeftPy, LeftPx, RightPy, RightPx\n"); + printf("additional: desired_grid_size\n"); exit(1); } @@ -68,6 +81,8 @@ int main(int argc, char* argv[]) const index_t in_right_pad_h = std::stoi(argv[20]); const index_t in_right_pad_w = std::stoi(argv[21]); + const index_t desired_grid_size = std::stoi(argv[22]); + const index_t YEff = (Y - 1) * conv_dilation_h + 1; const index_t XEff = (X - 1) * conv_dilation_w + 1; @@ -114,16 +129,19 @@ int main(int argc, char* argv[]) #if 0 using in_data_t = float; + using wei_data_t = float; using acc_data_t = float; using out_data_t = float; #elif 1 using in_data_t = half_t; - using acc_data_t = float; using out_data_t = half_t; + using acc_data_t = float; + using wei_data_t = float; #elif 1 using in_data_t = int8_t; - using acc_data_t = int32_t; using out_data_t = int8_t; + using acc_data_t = int32_t; + using wei_data_t = int8_t; #endif std::vector in_lengths_host(4), wei_lengths_host(4), out_lengths_host(4); @@ -164,8 +182,8 @@ int main(int argc, char* argv[]) } Tensor in(in_lengths_host); - Tensor wei_device(wei_lengths_host); - Tensor wei_host(wei_lengths_host); + Tensor wei_device(wei_lengths_host); + Tensor wei_host(wei_lengths_host); Tensor out(out_lengths_host); std::cout << "layout: " << layout << std::endl; @@ -231,6 +249,26 @@ int main(int argc, char* argv[]) in_right_pads_dev); }; + auto f_make_for_device_nhwc = [&]() { + const auto in_lengths_dev = make_tuple(N, Hi, Wi, C); + const auto wei_lengths_dev = make_tuple(K, Y, X, C); + const auto out_lengths_dev = make_tuple(N, Ho, Wo, K); + const auto conv_strides_dev = make_tuple(conv_stride_h, conv_stride_w); + const auto conv_dilations_dev = make_tuple(conv_dilation_h, conv_dilation_w); + const auto in_left_pads_dev = make_tuple(in_left_pad_h, in_left_pad_w); + const auto in_right_pads_dev = make_tuple(in_right_pad_h, in_right_pad_w); + + return make_tuple(in_lengths_dev, + wei_lengths_dev, + out_lengths_dev, + conv_strides_dev, + conv_dilations_dev, + in_left_pads_dev, + in_right_pads_dev); + }; + + // set zero to wei_device + wei_device.GenerateTensorValue(GeneratorTensor_0{}, num_thread); #if USE_CONV_WRW_V4R4R2_XDL_NCHW if(algo == ConvBackwardWeightAlgo::V4R4R2XDLNCHW) { @@ -242,6 +280,7 @@ int main(int argc, char* argv[]) const auto tmp = f_make_for_device_nchw(); device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( tmp[I0], @@ -258,6 +297,121 @@ int main(int argc, char* argv[]) } #endif +#if USE_CONV_WRW_V4R4R4_XDL_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R4XDLNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( + tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R2_XDL_ATOMIC_NCHW + if(algo == ConvBackwardWeightAlgo::V4R4R2XDLATOMICNCHW) + { + if(layout != ConvTensorLayout::NCHW) + { + throw std::runtime_error("wrong! 
layout"); + } + + const auto tmp = f_make_for_device_nchw(); + + device_convolution_backward_weight_implicit_gemm_v4r4r2_xdlops_atomic_nchw_kcyx_nkhw< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R4_XDL_ATOMIC_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R4XDLATOMICNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r4_xdlops_atomic_nhwc_kyxc_nhwk< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + +#if USE_CONV_WRW_V4R4R5_XDL_ATOMIC_NHWC + if(algo == ConvBackwardWeightAlgo::V4R4R5XDLATOMICNHWC) + { + if(layout != ConvTensorLayout::NHWC) + { + throw std::runtime_error("wrong! layout"); + } + + const auto tmp = f_make_for_device_nhwc(); + + device_convolution_backward_weight_implicit_gemm_v4r4r5_xdlops_atomic_nhwc_kyxc_nhwk< + in_data_t, + wei_data_t, + acc_data_t, + out_data_t>(tmp[I0], + tmp[I1], + tmp[I2], + tmp[I3], + tmp[I4], + tmp[I5], + tmp[I6], + in, + wei_device, + out, + desired_grid_size, + nrepeat); + } +#endif + if(do_verification) { host_direct_convolution_backward_weights(out, diff --git a/host/host_tensor/include/device.hpp b/host/host_tensor/include/device.hpp index 9b66f24f7a..cb1a6effa1 100644 --- a/host/host_tensor/include/device.hpp +++ b/host/host_tensor/include/device.hpp @@ -2,6 +2,7 @@ #define DEVICE_HPP #include +#include #include #include #include "hip/hip_runtime.h" @@ -80,5 +81,4 @@ float launch_and_time_kernel( return timer.GetElapsedTime() / nrepeat; } - #endif diff --git a/host/host_tensor/include/host_tensor_generator.hpp b/host/host_tensor/include/host_tensor_generator.hpp index 7c09843d01..b0d53995ed 100644 --- a/host/host_tensor/include/host_tensor_generator.hpp +++ b/host/host_tensor/include/host_tensor_generator.hpp @@ -15,6 +15,17 @@ struct GeneratorTensor_1 } }; +struct GeneratorTensor_0 +{ + int value = 0; + + template + float operator()(Is...) + { + return value; + } +}; + struct GeneratorTensor_2 { int min_value = 0; From c3018794b4b0c22187ddecc3547cf002afdd8c45 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Tue, 19 Oct 2021 18:43:10 -0500 Subject: [PATCH 14/15] bug fix (#39) --- .../tensor_operation/threadwise_tensor_slice_transfer.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp index 157828bf0f..7e3f6b3489 100644 --- a/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp +++ b/composable_kernel/include/tensor_operation/threadwise_tensor_slice_transfer.hpp @@ -397,7 +397,7 @@ struct ThreadwiseTensorSliceTransfer_v2 "wrong! 
SrcDesc need to known at compile-time"); } - __device__ void SetDstSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) + __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) { src_coord_ = make_tensor_coordinate(src_desc, src_slice_origin_idx); } From d5297abae9b284097dc637976c5d39ec9dc3e700 Mon Sep 17 00:00:00 2001 From: Chao Liu Date: Thu, 21 Oct 2021 16:42:24 -0500 Subject: [PATCH 15/15] fix bug in gridwise gemm xdlops v2r3 (#45) --- .../gridwise_gemm_xdlops_v2r3.hpp | 122 ++++++++------- ...icit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp | 4 +- .../include/driver_gemm_xdlops_v2r3.hpp | 140 +++++++++++++----- 3 files changed, 177 insertions(+), 89 deletions(-) diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index e3b0054bec..86e047c965 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -19,7 +19,8 @@ template + typename CBlockClusterAdaptor, + bool HasMainKBlockLoop> __global__ void #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU) @@ -37,14 +38,14 @@ __global__ void __shared__ FloatAB p_shared_block[shared_block_size]; - GridwiseGemm::Run(p_a_grid, - p_b_grid, - p_c_grid, - p_shared_block, - a_k0_m_k1_grid_desc, - b_k0_n_k1_grid_desc, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_block_cluster_adaptor); + GridwiseGemm::template Run(p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER template (p_a_grid, + p_b_grid, + p_c_grid, + p_shared_block, + a_k0_m_k1_grid_desc, + b_k0_n_k1_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_block_cluster_adaptor); } #endif @@ -102,7 +103,7 @@ template {}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -173,13 +174,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -217,7 +218,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 K1 == b_k0_n_k1_grid_desc.GetLength(I2))) return false; - if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % KPerBlock == 0)) + if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0)) return false; // check M01, N01 @@ -245,6 +246,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return grid_size; } + __host__ __device__ static constexpr bool CalculateHasMainK0BlockLoop(index_t K0) + { + const bool has_main_k0_block_loop = (K0 / K0PerBlock) > 1; + + return has_main_k0_block_loop; + } + __host__ __device__ static constexpr auto MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) { @@ -255,13 +263,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if constexpr(ABlockLdsExtraM) { return make_naive_tensor_descriptor( - 
make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -270,13 +278,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -334,6 +342,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{}, 1, 1)); + template __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, FloatC* __restrict__ p_c_grid, @@ -371,13 +380,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if constexpr(ABlockLdsExtraM) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -386,13 +395,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 if constexpr(BBlockLdsExtraN) { return make_naive_tensor_descriptor( - make_tuple(Number{}, Number{}, K1), + make_tuple(Number{}, Number{}, K1), make_tuple(Number{} * K1, K1, I1)); } else { return make_naive_tensor_descriptor_aligned( - make_tuple(Number{}, Number{}, K1), max_lds_align); + make_tuple(Number{}, Number{}, K1), max_lds_align); } }(); @@ -400,7 +409,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto a_blockwise_copy = BlockwiseTensorSliceTransfer_v4, + Sequence, ABlockTransferThreadSliceLengths_K0_M_K1, ABlockTransferThreadClusterLengths_K0_M_K1, ABlockTransferThreadClusterArrangeOrder, @@ -426,7 +435,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 auto b_blockwise_copy = BlockwiseTensorSliceTransfer_v4, + Sequence, BBlockTransferThreadSliceLengths_K0_N_K1, BBlockTransferThreadClusterLengths_K0_N_K1, BBlockTransferThreadClusterArrangeOrder, @@ -450,8 +459,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // GEMM definition // c_mtx += transpose(a_mtx) * b_mtx - // a_mtx[KPerBlock, MPerBlock] is in LDS - // b_mtx[KPerBlock, NPerBlock] is in LDS + // a_mtx[K0PerBlock, MPerBlock] is in LDS + // b_mtx[K0PerBlock, NPerBlock] is in LDS // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in // register // sanity check @@ -477,8 +486,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAB* p_a_block = p_shared_block; FloatAB* p_b_block = p_shared_block + a_block_space_size; - constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); - constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0, 0); + constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); + constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0); // hack to control index calculation when iterating over A and B matrix for threadwise copy constexpr auto a_k0_m_k1_grid_step_hacks = AGridStepHacks{}; @@ -504,32 +513,37 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } 
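+        // [editor note] HasMainKBlockLoop turns the main loop into a compile-time choice:
+        // CalculateHasMainK0BlockLoop returns (K0 / K0PerBlock) > 1, and when that is false
+        // the do-while below is discarded by "if constexpr", leaving only the tail GEMM.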
// main body - index_t k_block_data_begin = 0; + index_t k0_block_data_begin = 0; - do + if constexpr(HasMainKBlockLoop) { - a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, - a_block_slice_copy_step, - a_k0_m_k1_grid_move_slice_window_step_hack); - b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, - b_block_slice_copy_step, - b_k0_n_k1_grid_move_slice_window_step_hack); + do + { + a_blockwise_copy.MoveSrcSliceWindow(a_k0_m_k1_grid_desc, + a_block_slice_copy_step, + a_k0_m_k1_grid_move_slice_window_step_hack); + b_blockwise_copy.MoveSrcSliceWindow(b_k0_n_k1_grid_desc, + b_block_slice_copy_step, + b_k0_n_k1_grid_move_slice_window_step_hack); - a_blockwise_copy.RunRead(a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); + a_blockwise_copy.RunRead( + a_k0_m_k1_grid_desc, a_grid_buf, a_k0_m_k1_grid_step_hacks); - block_sync_lds(); + block_sync_lds(); - b_blockwise_copy.RunRead(b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); + b_blockwise_copy.RunRead( + b_k0_n_k1_grid_desc, b_grid_buf, b_k0_n_k1_grid_step_hacks); - blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); + blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); - block_sync_lds(); + block_sync_lds(); - a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); - b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); + a_blockwise_copy.RunWrite(a_k0_m_k1_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_k0_n_k1_block_desc, b_block_buf); - k_block_data_begin += KPerBlock; - } while(k_block_data_begin < (K0 - KPerBlock)); + k0_block_data_begin += K0PerBlock; + } while(k0_block_data_begin < (K0 - K0PerBlock)); + } // tail { diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 01e5c57ab4..1b23aa1a8c 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -160,7 +160,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 0 +#elif 1 // [M, N, K0, K1] = [128, 256, 4, 8], C = 128, for fp16 constexpr index_t BlockSize = 256; @@ -188,7 +188,7 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; constexpr index_t GemmCThreadTransferDstScalarPerVector = 1; -#elif 1 +#elif 0 // [M, N, K0, K1] = [128, 128, 4, 8], C = 64, for fp16 constexpr index_t BlockSize = 256; diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index 91ea24f947..4ccfbaab0a 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -148,28 +148,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, const index_t grid_size = GridwiseGemm::CalculateGridSize(c_m_n_grid_desc); - const auto kernel = kernel_gemm_xdlops_v2r3, - remove_reference_t, - remove_reference_t, - remove_reference_t>; + const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); + + const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); + + float ave_time = 0; #if 
CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
-    float ave_time = launch_and_time_kernel(kernel,
-                                            nrepeat,
-                                            dim3(grid_size),
-                                            dim3(BlockSize),
-                                            0,
-                                            p_a_grid,
-                                            p_b_grid,
-                                            p_c_grid,
-                                            a_k0_m_k1_grid_desc,
-                                            b_k0_n_k1_grid_desc,
-                                            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
-                                            c_block_cluster_adaptor);
+    if(has_main_k0_block_loop)
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r3,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    true>;
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          p_a_grid,
+                                          p_b_grid,
+                                          p_c_grid,
+                                          a_k0_m_k1_grid_desc,
+                                          b_k0_n_k1_grid_desc,
+                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                          c_block_cluster_adaptor);
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r3,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    false>;
+
+        ave_time = launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),
+                                          0,
+                                          p_a_grid,
+                                          p_b_grid,
+                                          p_c_grid,
+                                          a_k0_m_k1_grid_desc,
+                                          b_k0_n_k1_grid_desc,
+                                          c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc,
+                                          c_block_cluster_adaptor);
+    }
 #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
     DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc));
     DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc));
@@ -181,20 +214,61 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid,
     c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc);
     c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor);

-    float ave_time = launch_and_time_kernel(
-        kernel,
-        nrepeat,
-        dim3(grid_size),
-        dim3(BlockSize),
-        0,
-        p_a_grid,
-        p_b_grid,
-        p_c_grid,
-        cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(
-            c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
-        cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    if(has_main_k0_block_loop)
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r3,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    true>;
+
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
+    else
+    {
+        const auto kernel = kernel_gemm_xdlops_v2r3,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    remove_reference_t,
+                                                    false>;
+
+        ave_time = launch_and_time_kernel(
+            kernel,
+            nrepeat,
+            dim3(grid_size),
+            dim3(BlockSize),
+            0,
+            p_a_grid,
+            p_b_grid,
+            p_c_grid,
+            cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()),
+            cast_pointer_to_constant_address_space(
+                c_block_cluster_adaptor_dev_buf.GetDeviceBuffer()));
+    }
 #endif
     return ave_time;
 }