diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index ee6a0b7427..0c381de0d8 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -9,16 +9,15 @@ namespace ck { template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 +struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { - - using CIndex = MultiIndex<2>; - static constexpr auto I0 = Number<0>{}; static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; @@ -26,111 +25,165 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 static constexpr index_t WaveSize = 64; - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); + static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1); + static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1); - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); + static constexpr index_t K0 = BK0NK1BlockDesc{}.GetLength(I0); + static constexpr index_t KPerBlock = K0; - static constexpr auto xdlops_gemm = XdlopsGemm{}; + static constexpr auto xdlops_gemm = XdlopsGemm{}; - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; + static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); + static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; + __device__ static auto GetWaveIdx() + { + const index_t thread_id = get_thread_local_1d_id(); - __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } + const auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))), 
+ make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); - __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } - - __device__ constexpr auto GetBlkSize() const { return xdlops_gemm.GetCLayout().GetBlkSize(); } + return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id)); + } __device__ static auto CalculateAThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } + const auto waveId_m = wave_idx[I0]; + + const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex(); + + return make_tuple(xdlops_a_idx[I0], 0, waveId_m, xdlops_a_idx[I1], 0); } __device__ static auto CalculateBThreadOriginDataIndex() { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; + const auto wave_idx = GetWaveIdx(); - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } + const auto waveId_n = wave_idx[I1]; + + const auto xdlops_b_idx = 
xdlops_gemm.CalculateBThreadOriginDataIndex(); + + return make_tuple(xdlops_b_idx[I0], 0, waveId_n, xdlops_b_idx[I1], 0); } template - __device__ static CIndex + __device__ static auto CalculateCThreadOriginDataIndex(Number, Number, Number, Number) { + const auto wave_idx = GetWaveIdx(); - const index_t waveId = get_thread_local_1d_id() / WaveSize; + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); + const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; + constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; + constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}), + make_tuple(Sequence<0, 1, 2>{})); - return CIndex{m_offset, n_offset}; + const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex( + make_tuple(m0, waveId_m, blk_idx[I0]))[I0]; + const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex( + make_tuple(n0, waveId_n, blk_idx[I1]))[I0]; + + return make_tuple(c_thread_m, c_thread_n); } - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1() - : a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} + __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1() { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), + 
static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() && + BK0NK1BlockDesc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time"); - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); + static_assert(AK0MK1BlockDesc{}.GetLength(I0) == BK0NK1BlockDesc{}.GetLength(I0), + "wrong! K0 dimension not consistent"); - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), + static_assert(AK0MK1BlockDesc{}.GetLength(I2) == BK0NK1BlockDesc{}.GetLength(I2), "wrong! K1 dimension not consistent"); static_assert(BlockSize == MWaves * NWaves * WaveSize, "BlockSize != MWaves * NWaves * WaveSize\n"); - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); + static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0, + "wrong!"); } + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2ThreadDescriptor() + { + constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths(); + + constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0]; + constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1]; + constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2]; + constexpr auto N = c_m0_m1_m2_n_tblk_lens[I3]; + + return make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, I1, M0, M1, M2, N)); + } + + __host__ __device__ static constexpr auto GetCM0N0M1N1M2M3M4N2BlockDescriptor() + { + constexpr auto c_m0_n0_m1_n1_m2_n2_block_desc = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, + Number{}, + Number{}, + Number{}, + Number{}, + Number{})); + + return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_block_desc); + } + + template + __host__ __device__ static constexpr auto + 
MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + { + const auto c_m0_n0_m1_n1_m2_n2_grid_desc = transform_tensor_descriptor( + c_m_n_grid_desc, + make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL)), + make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))), + make_tuple(Sequence<0>{}, Sequence<1>{}), + make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{})); + + return xdlops_gemm.MakeCM0N0M1N1M2M3M4N2Descriptor(c_m0_n0_m1_n1_m2_n2_grid_desc); + } + + __host__ __device__ static constexpr auto MakeAK0M0M1M2K1BlockDescriptor() + { + return transform_tensor_descriptor( + AK0MK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + __host__ __device__ static constexpr auto MakeBK0N0N1N2K1BlockDescriptor() + { + return transform_tensor_descriptor( + BK0NK1BlockDesc{}, + make_tuple(make_pass_through_transform(Number{}), + make_unmerge_transform( + make_tuple(Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), + make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{})); + } + + static constexpr auto a_k0_m0_m1_m2_k1_block_desc = MakeAK0M0M1M2K1BlockDescriptor(); + static constexpr auto b_k0_n0_n1_n2_k1_block_desc = MakeBK0N0N1N2K1BlockDescriptor(); + template __device__ void Run(const ABlockBuffer& a_block_buf, const BBlockBuffer& b_block_buf, @@ -141,49 +194,48 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); + vector_type a_thread_vec; - vector_type a_thread_vec; + vector_type b_thread_vec; - vector_type b_thread_vec; - - static_for<0, KPerBlock, 
xdlops_gemm.KPerXdlops>{}([&](auto k) { + static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) { // read A - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), + a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, + make_tuple(k0, I0, I0, I0, I0), a_block_buf, a_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0), a_thread_buf); // read B - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), + b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, + make_tuple(k0, I0, I0, I0, I0), b_block_buf, b_thread_desc_, - make_tuple(I0, I0, I0, I0), + make_tuple(I0, I0, I0, I0, I0), b_thread_buf); - using mfma_input_type = - typename vector_type::type; - - static_for<0, a_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - a_thread_vec.template AsType()(Number{}) = a_thread_buf[Number{}]; - }); - - static_for<0, b_thread_desc_.GetElementSpaceSize(), 1>{}([&](auto i) { - b_thread_vec.template AsType()(Number{}) = b_thread_buf[Number{}]; - }); + using mfma_input_type = typename vector_type::type; static_for<0, MRepeat, 1>{}([&](auto m0) { static_for<0, NRepeat, 1>{}([&](auto n0) { - xdlops_gemm.template Run(a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf); + static_for<0, K1, 1>{}([&](auto i) { + a_thread_vec.template AsType()(i) = a_thread_buf + [Number{}]; + }); + + static_for<0, K1, 1>{}([&](auto i) { + b_thread_vec.template AsType()(i) = b_thread_buf + [Number{}]; + }); + + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + + xdlops_gemm.template Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf); }); }); }); @@ -191,333 +243,38 @@ struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1 private: // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( + 
make_tuple(I1, Number{}, I1, I1, Number{})); // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); + static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(I1, Number{}, I1, I1, Number{})); - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{}, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, MRepeat, 1, 1, K1>, + Sequence<0, 1, 2, 3, 4>, + 4, K1, 1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, + Sequence<1, NRepeat, 1, 1, K1>, + Sequence<0, 1, 2, 3, 4>, + 4, K1, 1>; - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; -}; - -template -struct BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline -{ - - using CIndex = MultiIndex<2>; - - static constexpr auto I0 = Number<0>{}; - static constexpr auto I1 = Number<1>{}; - static constexpr auto I2 = Number<2>{}; - static constexpr auto I3 = Number<3>{}; - - static constexpr auto xdlops_gemm = XdlopsGemm{}; - - static constexpr index_t WaveSize = 64; - - static constexpr index_t M0 = ABlockDesc{}.GetLength(I1); - static constexpr index_t M1 = ABlockDesc{}.GetLength(I2); - - static constexpr index_t N0 = BBlockDesc{}.GetLength(I1); - static constexpr index_t N1 = BBlockDesc{}.GetLength(I2); - - static constexpr index_t MWaves = M1 / MPerWave; - static constexpr index_t NWaves = N1 / NPerWave; - - static constexpr index_t MRepeat = M0; - static constexpr index_t NRepeat = N0; - - __device__ constexpr auto GetCLayout() const { return xdlops_gemm.GetCLayout(); } - - __device__ constexpr auto GetNumBlks() const { return xdlops_gemm.GetCLayout().GetNumBlks(); } - - __device__ constexpr auto GetBlkSize() const { return 
xdlops_gemm.GetCLayout().GetBlkSize(); } - - __device__ static auto CalculateAThreadOriginDataIndex() - { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_m = waveId / NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t m_offset = waveId_m * MPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, m_offset, 0); - } - else - { - const index_t m_offset = waveId_m * MPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, m_offset, 0); - } - } - - __device__ static auto CalculateBThreadOriginDataIndex() - { - const index_t thread_id = get_thread_local_1d_id(); - const index_t waveId = thread_id / WaveSize; - const index_t laneId = thread_id % WaveSize; - const index_t waveId_n = waveId % NWaves; - - if constexpr(xdlops_gemm.IsKReduction) - { - const index_t n_offset = waveId_n * NPerWave + xdlops_gemm.GetBlkTd(laneId); - const index_t k_offset = xdlops_gemm.GetBlkId(laneId); - return make_tuple(k_offset, 0, n_offset, 0); - } - else - { - const index_t n_offset = waveId_n * NPerWave + laneId; - const index_t k_offset = 0; - return make_tuple(k_offset, 0, n_offset, 0); - } - } - - template - __device__ static CIndex - CalculateCThreadOriginDataIndex(Number, Number, Number, Number) - { - - const index_t waveId = get_thread_local_1d_id() / WaveSize; - - const auto thread_mtx_on_blk = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i); - - const index_t waveId_m = waveId / NWaves; - const index_t waveId_n = waveId % NWaves; - - const index_t m_offset = m0 * M1 + waveId_m * MPerWave + thread_mtx_on_blk[I0]; - const index_t n_offset = n0 * N1 + waveId_n * NPerWave + thread_mtx_on_blk[I1]; - - return CIndex{m_offset, n_offset}; - } - - __device__ BlockwiseGemmXdlops_km_kn_m0m1m2n_v1_2x2pipeline() - : 
a_thread_copy_{CalculateAThreadOriginDataIndex()}, - b_thread_copy_{CalculateBThreadOriginDataIndex()} - { - static_assert(ABlockDesc::IsKnownAtCompileTime() && BBlockDesc::IsKnownAtCompileTime(), - "wrong! Desc should be known at compile-time"); - - static_assert(ABlockDesc{}.GetLength(I0) == BBlockDesc{}.GetLength(I0), - "wrong! K dimension not consistent"); - - static_assert(ABlockDesc{}.GetLength(I3) == BBlockDesc{}.GetLength(I3), - "wrong! K1 dimension not consistent"); - - static_assert(BlockSize == MWaves * NWaves * WaveSize, - "BlockSize != MWaves * NWaves * WaveSize\n"); - - static_assert(K1 == BBlockDesc{}.GetLength(I3), "K1 is wrong!"); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - static_assert(KPerBlock % xdlops_gemm.KPerXdlops == 0, "KPerBlock is wrong!"); - - static_assert(K1 % xdlops_gemm.mfma_type.k_base == 0, "K1 is wrong!"); - } - - template - __device__ void Run(const ABlockBuffer& a_block_buf, - const BBlockBuffer& b_block_buf, - CThreadBuffer& c_thread_buf) const - { - auto a_thread_buf = make_static_buffer( - a_thread_desc_.GetElementSpaceSize()); - auto b_thread_buf = make_static_buffer( - b_thread_desc_.GetElementSpaceSize()); - - constexpr index_t KPerBlock = ABlockDesc{}.GetLength(I0); - - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(I0, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(I0, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template 
Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - static_for{}([&](auto k) { - // read A_sub_0 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I0, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I0, I0, I0), - a_thread_buf); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // read B_sub_0 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0), - b_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // read B_sub_1 - b_thread_copy_.Run(BBlockDesc{}, - make_tuple(k, I1, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I1, I0, I0), - b_thread_buf); - - // read A_sub_1 - a_thread_copy_.Run(ABlockDesc{}, - make_tuple(k, I1, I0, I0), - a_block_buf, - a_thread_desc_, - make_tuple(I0, I1, I0, I0), - a_thread_buf); - - // C_sub_00 += transpose(A_sub_0) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_01 += transpose(A_sub_0) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - }); - - // C_sub_10 += transpose(A_sub_1) * B_sub_0 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - - // C_sub_11 += transpose(A_sub_1) * B_sub_1 - xdlops_gemm.template Run(a_thread_buf, b_thread_buf, c_thread_buf); - } - - private: - // A[K, M] - static constexpr auto a_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - // B[K, N] - static constexpr auto b_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(I1, Number{}, I1, Number{})); - - static constexpr auto c_thread_desc_ = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - using AThreadCopy = 
ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, - Sequence<0, 1, 2, 3>, - 3, - 1, // K1, - 1>; - - AThreadCopy a_thread_copy_; - BThreadCopy b_thread_copy_; + AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; + BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; }; } // namespace ck diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 207f73072f..51af11fefc 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -18,7 +18,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -29,7 +29,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const AK0MK1GridDesc a_k0_m_k1_grid_desc, const BK0NK1GridDesc b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -43,7 +43,7 @@ __global__ void p_shared_block, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); } #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER @@ -52,7 +52,7 @@ template __global__ void #if CK_USE_LAUNCH_BOUNDS @@ -63,7 +63,7 @@ __global__ void FloatC* __restrict__ p_c_grid, const void CONSTANT* p_a_k0_m_k1_grid_desc, const void CONSTANT* p_b_k0_n_k1_grid_desc, - const void CONSTANT* p_c_m0_m1_m2_n_grid_desc, + const void CONSTANT* p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const void CONSTANT* p_c_block_cluster_adaptor) { constexpr index_t shared_block_size = @@ -73,8 +73,9 @@ __global__ void cast_pointer_to_generic_address_space(p_a_k0_m_k1_grid_desc)); const auto b_k0_n_k1_grid_desc = *reinterpret_cast( 
cast_pointer_to_generic_address_space(p_b_k0_n_k1_grid_desc)); - const auto c_m0_m1_m2_n_grid_desc = *reinterpret_cast( - cast_pointer_to_generic_address_space(p_c_m0_m1_m2_n_grid_desc)); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + *reinterpret_cast( + cast_pointer_to_generic_address_space(p_c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc)); const auto c_block_cluster_adaptor = *reinterpret_cast( cast_pointer_to_generic_address_space(p_c_block_cluster_adaptor)); @@ -86,7 +87,7 @@ __global__ void p_shared_block, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); } #endif @@ -138,6 +139,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static constexpr auto I1 = Number<1>{}; static constexpr auto I2 = Number<2>{}; static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; + static constexpr auto I6 = Number<6>{}; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -201,29 +205,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 } __host__ __device__ static constexpr auto - MakeCM0M1M2NGridDescriptor(const CMNGridDesc& c_m_n_grid_desc) + MakeCM0N0M1N1M2M3M4N2GridDescriptor(const CMNGridDesc& c_m_n_grid_desc) { - constexpr auto xdlops_gemm = XdlopsGemm{}; + constexpr auto max_lds_align = K1; - constexpr auto CLayout = xdlops_gemm.GetCLayout(); + constexpr auto a_k0_m_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); - constexpr auto M0 = Number{}; - constexpr auto M1 = Number{}; - constexpr auto M2 = Number{}; + constexpr auto b_k0_n_k1_block_desc = make_naive_tensor_descriptor_aligned( + make_tuple(Number{}, Number{}, K1), max_lds_align); - constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); + using BlockwiseGemm = + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1; - constexpr 
auto N1 = Number{}; - - const auto c_m0_m1_m2_n_grid_desc = transform_tensor_descriptor( - c_m_n_grid_desc, - make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, M0, M1, M2)), - make_unmerge_transform(make_tuple(NRepeat, NWaves, N1))), - make_tuple(Sequence<0>{}, Sequence<1>{}), - make_tuple(Sequence<0, 2, 4, 5, 6>{}, Sequence<1, 3, 7>{})); - - return c_m0_m1_m2_n_grid_desc; + return BlockwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); } __host__ __device__ static constexpr auto @@ -253,8 +256,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 return c_blockid_to_m0_n0_block_cluster_adaptor; } - using CM0M1M2NGridDesc = decltype(MakeCM0M1M2NGridDescriptor(CMNGridDesc{})); - using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); + using CM0N0M1N1M2M3M4N2GridDesc = decltype(MakeCM0N0M1N1M2M3M4N2GridDescriptor(CMNGridDesc{})); + using CBlockClusterAdaptor = decltype(MakeCBlockClusterAdaptor(CMNGridDesc{})); __device__ static void Run(const FloatAB* __restrict__ p_a_grid, const FloatAB* __restrict__ p_b_grid, @@ -262,7 +265,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 FloatAB* __restrict__ p_shared_block, const AK0MK1GridDesc& a_k0_m_k1_grid_desc, const BK0NK1GridDesc& b_k0_n_k1_grid_desc, - const CM0M1M2NGridDesc& c_m0_m1_m2_n_grid_desc, + const CM0N0M1N1M2M3M4N2GridDesc& c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, const CBlockClusterAdaptor& c_block_cluster_adaptor) { const auto a_grid_buf = make_dynamic_buffer( @@ -270,7 +273,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const auto b_grid_buf = make_dynamic_buffer( p_b_grid, b_k0_n_k1_grid_desc.GetElementSpaceSize()); auto c_grid_buf = make_dynamic_buffer( - p_c_grid, c_m0_m1_m2_n_grid_desc.GetElementSpaceSize()); + p_c_grid, c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc.GetElementSpaceSize()); const auto K0 = a_k0_m_k1_grid_desc.GetLength(I0); @@ -358,50 +361,26 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 // register // sanity check - static_assert(MPerBlock % 
(MPerWave * MRepeat) == 0 && - NPerBlock % (NPerWave * NRepeat) == 0, - "wrong!"); - - constexpr auto a_k0_m0_m1_k1_block_desc = transform_tensor_descriptor( - a_k0_m_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - - constexpr auto b_k0_n0_n1_k1_block_desc = transform_tensor_descriptor( - b_k0_n_k1_block_desc, - make_tuple(make_pass_through_transform(Number{}), - make_unmerge_transform( - make_tuple(Number{}, Number{})), - make_pass_through_transform(K1)), - make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), - make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3>{})); - const auto blockwise_gemm = - BlockwiseGemmXdlops_km_kn_m0m1m2n_v1{}; - - constexpr auto CLayout = blockwise_gemm.GetCLayout(); - - constexpr index_t BlkSize = CLayout.GetBlkSize(); - constexpr index_t NumBlks = CLayout.GetNumBlks(); - constexpr index_t NumXdlops = CLayout.GetNumXdlops(); - - static_assert(NumBlks == 1 && NumXdlops == 1, "K Reduction Mfma only"); + BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; constexpr auto c_mr_nr_blk_desc = make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor(); + constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize(); + StaticBuffer, + vector_type, c_mr_nr_blk_desc.GetElementSpaceSize(), true> c_thread_buf; @@ -474,41 +453,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 blockwise_gemm.Run(a_block_buf, b_block_buf, c_thread_buf); } -#if 0 // output: register to global memory { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = + 
blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); - constexpr index_t N0 = CLayout.N1(); - constexpr index_t N1 = CLayout.N0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, - Number{}, - Number<1>{}, - Number<1>{}, - Number{}, - Number<1>{}, - Number{}, - Number<1>{})); - - StaticBuffer - c_blk_buf_; - - static_for<0, MRepeat, 1>{}([&](auto mr_i) { - static_for<0, NRepeat, 1>{}([&](auto nr_i) { - constexpr auto blk_off = - c_mr_nr_blk_desc.CalculateOffset(make_tuple(mr_i, nr_i)); - - static_for<0, BlkSize, 1>{}([&](auto j) { - c_blk_buf_(Number{}) = - c_thread_buf[Number{}] - .template AsType()[Number{}]; - }); - }); - }); + constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); + constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); + constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -521,145 +473,96 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 const index_t n_thread_data_on_grid = n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; - - constexpr index_t MWaves = MPerBlock / (MPerWave * MRepeat); - constexpr index_t NWaves = NPerBlock / (NPerWave * NRepeat); - - ThreadwiseTensorSliceTransfer_v1r3< - FloatC, - FloatC, - decltype(c_m0_m1_m2_n_thread_desc), - decltype(c_m0_m1_m2_n_grid_desc), - Sequence, - CThreadTransferSrcDstAccessOrder, - CThreadTransferSrcDstVectorDim, - CThreadTransferDstScalarPerVector, - CGlobalMemoryDataOperation, - 1, - true>{ - c_m0_m1_m2_n_grid_desc, - make_multi_index(m_thread_data_on_grid / (M2 * M1 * M0 * MWaves), - n_thread_data_on_grid / (N1 * NWaves), - m_thread_data_on_grid % (M2 * M1 * M0 * MWaves) / (M2 * M1 * M0), - n_thread_data_on_grid % (N1 * NWaves) / N1, - m_thread_data_on_grid % (M2 * M1 * M0) / (M2 * M1), - 
m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, - n_thread_data_on_grid % N1)} - .Run(c_m0_m1_m2_n_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_blk_buf_, - c_m0_m1_m2_n_grid_desc, - c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); - } -#else - { - constexpr index_t M0 = CLayout.M1(); - constexpr index_t M1 = CLayout.N1(); - constexpr index_t M2 = CLayout.M0(); - - constexpr auto c_m0_m1_m2_n_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(I1, I1, I1, I1, Number{}, Number<1>{}, Number{}, Number<1>{})); - - // calculate origin of thread output tensor on global memory - // blockwise GEMM c matrix starting index - const auto c_thread_mtx_on_block = - blockwise_gemm.CalculateCThreadOriginDataIndex(I0, I0, I0, I0); - - const index_t m_thread_data_on_grid = - m_block_data_idx_on_grid + c_thread_mtx_on_block[I0]; - - const index_t n_thread_data_on_grid = - n_block_data_idx_on_grid + c_thread_mtx_on_block[I1]; - - constexpr auto c_m0_m1_m2_n_grid_tensor_step_hacks = CGridStepHacks{}; + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; auto c_thread_copy = ThreadwiseTensorSliceTransfer_v1r3, + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc), + decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc), + Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, true>{ - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, make_multi_index(0, 0, 0, 0, - m_thread_data_on_grid / (M2 * M1), - m_thread_data_on_grid % (M2 * M1) / M2, - m_thread_data_on_grid % M2, + m_thread_data_on_grid / (M3 * M4), + m_thread_data_on_grid % (M3 * M4) / M4, + m_thread_data_on_grid % M4, n_thread_data_on_grid)}; auto init_copy = [&](auto c_thread_idx_) { constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + 
c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); return c_thread_idx_; }; auto mrepeat_plus_copy = [&](auto c_thread_idx_) { constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + mrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto nrepeat_plus_copy = [&](auto c_thread_idx_) { constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + nrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto mrepeat_minus_copy = [&](auto c_thread_idx_) { constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); - 
c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, mrepeat_step_plus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + mrepeat_step_plus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; auto nrepeat_minus_copy = [&](auto c_thread_idx_) { constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_m1_m2_n_grid_desc, nrepeat_step_minus); + c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + nrepeat_step_minus); constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_m1_m2_n_thread_desc, + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), c_thread_buf[Number{}].template AsType(), - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_grid_buf, - c_m0_m1_m2_n_grid_tensor_step_hacks); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); }; static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or @@ -791,7 +694,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 init_copy(make_tuple(I0, I0)); } } -#endif } }; // namespace ck diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index affe096ace..56581c024b 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -7,21 +7,18 @@ namespace ck { -enum struct mfma_instr +enum struct MfmaInstr { - /// fp32 mfma_f32_32x32x1xf32 = 0, 
mfma_f32_16x16x1xf32, mfma_f32_4x4x1xf32, mfma_f32_32x32x2xf32, // k reduction mfma_f32_16x16x4xf32, // k reduction - /// fp16 mfma_f32_32x32x4f16, mfma_f32_16x16x4f16, mfma_f32_4x4x4f16, mfma_f32_32x32x8f16, // k reduction mfma_f32_16x16x16f16, // k reduction - /// bfp16 mfma_f32_32x32x2bf16, mfma_f32_16x16x2bf16, mfma_f32_4x4x2bf16, @@ -29,25 +26,23 @@ enum struct mfma_instr mfma_f32_16x16x8bf16, // k reduction }; -template -struct mfma_info; +template +struct mfma_type; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 1; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static 
constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + 
static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 1; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 1; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template // treat 4x4x1 as a single-blk 4x64 mfma template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 1; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 1; + static constexpr index_t 
group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 1; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 
64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 8; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 16; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t 
n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 4; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 4; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 4; + static constexpr index_t group_size = 4; + 
static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 4; + static constexpr bool is_k_reduction = false; template #if 0 template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 2; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 2; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 2; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 4; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 32; - static constexpr index_t wave_size = 64; - static 
constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 32; - static constexpr index_t n = 32; - static constexpr index_t k = 4; - static constexpr index_t cycles = 64; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 4; + static constexpr index_t num_regs_per_blk = 16; + static constexpr index_t num_threads_per_blk = 32; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 2; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 32; + static constexpr index_t n_per_blk = 32; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 8; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + 
static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = true; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 16; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = wave_size / num_threads_blk; - static constexpr index_t num_output_blks = 4; - static constexpr index_t num_regs_xdlops = num_regs_blk * num_output_blks; - static constexpr index_t m = 16; - static constexpr index_t n = 16; - static constexpr index_t k = 2; - static constexpr index_t cycles = 32; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 16; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 4; + static constexpr index_t num_output_blks = 4; + static constexpr index_t m_per_blk = 16; + static constexpr index_t n_per_blk = 16; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; template <> -struct mfma_info +struct mfma_type { - static constexpr index_t group_size = 4; - static constexpr index_t num_groups_blk = 1; - static constexpr index_t num_regs_blk = group_size * num_groups_blk; - static constexpr index_t num_threads_blk = 64; - static constexpr index_t wave_size = 64; - static constexpr index_t num_input_blks = 1; - static constexpr index_t num_output_blks = 1; - static constexpr index_t num_regs_xdlops = 4; - static constexpr index_t m = 4; - static constexpr index_t n = 64; - static constexpr index_t k = 2; - static constexpr index_t cycles = 8; - static constexpr index_t k_base = 2; + static constexpr index_t group_size = 4; + static constexpr 
index_t num_groups_per_blk = 1; + static constexpr index_t num_regs_per_blk = 4; + static constexpr index_t num_threads_per_blk = 64; + static constexpr index_t wave_size = 64; + static constexpr index_t num_input_blks = 1; + static constexpr index_t num_output_blks = 1; + static constexpr index_t m_per_blk = 4; + static constexpr index_t n_per_blk = 64; + static constexpr index_t k_per_blk = 2; + static constexpr bool is_k_reduction = false; template }; #endif -template -struct xdlops_info +template +struct MfmaSelector { - static constexpr auto mfma_type = mfma_info{}; + template + static constexpr auto GetMfma(); - static constexpr index_t MPerXdlops = MPerXdlops_; - static constexpr index_t NPerXdlops = NPerXdlops_; + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x1xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x2xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4xf32; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_32x32x8f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x16f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_16x16x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return 
MfmaInstr::mfma_f32_4x4x4f16; + } + + template <> + static constexpr auto GetMfma() + { + return MfmaInstr::mfma_f32_4x4x4f16; + } + +#if 0 + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } + + template <> + static constexpr auto GetMfma() + { + return xdlops_info{}; + } +#endif + + static constexpr auto selected_mfma = mfma_type()>{}; + + __host__ __device__ static constexpr void mfma_check() + { + static_assert(selected_mfma.group_size * selected_mfma.num_groups_per_blk == + selected_mfma.num_regs_per_blk, + "wrong! 
num_regs_per_blk"); + + static_assert(selected_mfma.num_threads_per_blk == selected_mfma.n_per_blk, + "n_per_blk != num_threads_per_blk"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.num_input_blks == + selected_mfma.m_per_blk, + "m_per_blk != num_input_blks * num_regs_per_blk"); + + static_assert(selected_mfma.num_output_blks == selected_mfma.num_input_blks || + selected_mfma.num_output_blks == 1, + "incorrect num_output_blks"); + + static_assert(selected_mfma.num_regs_per_blk * selected_mfma.wave_size == + selected_mfma.m_per_blk * selected_mfma.n_per_blk, + "num_regs_per_blk incorrect"); + + static_assert(selected_mfma.is_k_reduction || + (selected_mfma.num_input_blks == selected_mfma.num_output_blks), + "is_k_reduction wrong!"); + } + + __host__ __device__ constexpr MfmaSelector() { mfma_check(); } static constexpr bool IsABroadcast() { @@ -505,186 +652,33 @@ struct xdlops_info return true; } - static constexpr bool IsKReduction() - { - return (mfma_type.num_output_blks == 1) && (mfma_type.num_input_blks > 1); - } - static constexpr index_t GetKPerXdlops() { - return IsKReduction() ? mfma_type.num_input_blks : 1; + return (selected_mfma.is_k_reduction ? 
selected_mfma.num_input_blks : 1) * + selected_mfma.k_per_blk; } - static constexpr index_t GetNumCRegs() { return MPerXdlops * NPerXdlops / mfma_type.wave_size; } + static constexpr index_t GetKPerThread() { return selected_mfma.k_per_blk; } }; -template +template struct XdlopsGemm { - template - static constexpr auto GetXdlopsInfo(); - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - -#if 0 - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return 
xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } - - template <> - static constexpr auto GetXdlopsInfo() - { - return xdlops_info{}; - } -#endif + static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; + static constexpr auto I2 = Number<2>{}; + static constexpr auto I3 = Number<3>{}; + static constexpr auto I4 = Number<4>{}; + static constexpr auto I5 = Number<5>{}; using CIndex = MultiIndex<2>; - __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } + __device__ static constexpr index_t GetNumBlks() { return mfma_instr.num_output_blks; } __device__ static constexpr index_t GetNumXdlops() { - return MPerXdlops * NPerXdlops / (mfma_type.m * mfma_type.n * mfma_type.num_output_blks); + return MPerXdlops * NPerXdlops / + (mfma_instr.m_per_blk * mfma_instr.n_per_blk * mfma_instr.num_output_blks); } __host__ __device__ constexpr XdlopsGemm() @@ -697,104 +691,141 @@ struct XdlopsGemm MPerXdlops == 64, "Only support GemmMPerXdlops == 4, 8, 16, 32 or 64 for xdlops"); - static_assert(mfma_type.num_threads_blk == mfma_type.n, "n != num_threads_blk"); - static_assert(mfma_type.num_regs_blk * mfma_type.num_input_blks == mfma_type.m, - "m != num_input_blks * num_regs_blk"); - static_assert(mfma_type.num_output_blks == mfma_type.num_input_blks || - mfma_type.num_output_blks == 1, - "incorrect num_output_blks"); - static_assert(mfma_type.num_regs_blk * mfma_type.wave_size == mfma_type.m * mfma_type.n, - "num_regs_blk incorrect"); + static_assert(KPack % mfma_instr.k_per_blk == 0, "KPack cannot be divided by 
k_per_blk"); + } - static_assert(mfma_type.k % mfma_type.k_base == 0, "k % kbase != 0!"); + template + __host__ __device__ static constexpr auto + MakeCM0N0M1N1M2M3M4N2Descriptor(const CM0N0M1N1M2N2Desc& c_m0_n0_m1_n1_m2_n2_desc) + { + const auto M0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I0); + const auto N0 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I1); + const auto M1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I2); + const auto N1 = c_m0_n0_m1_n1_m2_n2_desc.GetLength(I3); + + return transform_tensor_descriptor( + c_m0_n0_m1_n1_m2_n2_desc, + make_tuple(make_pass_through_transform(M0), + make_pass_through_transform(N0), + make_pass_through_transform(M1), + make_pass_through_transform(N1), + make_unmerge_transform(make_tuple(mfma_instr.num_groups_per_blk, + mfma_instr.num_input_blks, + mfma_instr.group_size)), + make_pass_through_transform(mfma_instr.num_threads_per_blk)), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4>{}, + Sequence<5>{}), + make_tuple(Sequence<0>{}, + Sequence<1>{}, + Sequence<2>{}, + Sequence<3>{}, + Sequence<4, 5, 6>{}, + Sequence<7>{})); } __device__ static constexpr index_t GetRegSizePerXdlops() { - return MPerXdlops * NPerXdlops / mfma_type.wave_size; + return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } - template + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || is_same::value, "base base_type must be float, half, ushort!"); - static_assert(KPack % mfma_type.k_base == 0, "KPack cannot be divided by k_base"); - - constexpr index_t c_offset = CDesc{}.CalculateOffset(make_tuple(m0, n0)) * GetNumXdlops(); - - static_for<0, KPack, mfma_type.k_base>{}([&](auto k) { - constexpr index_t a_offset = ADesc{}.CalculateOffset(make_tuple(0, m0, 0, k)); - constexpr index_t b_offset = BDesc{}.CalculateOffset(make_tuple(0, n0, 0, k)); - - mfma_type.template run( - p_a_wave[Number{}], - p_b_wave[Number{}], - 
p_c_thread); + static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { + mfma_instr.template run( + p_a_wave[k], p_b_wave[k], p_c_thread); }); } + __device__ static auto GetLaneId() { return get_thread_local_1d_id() % mfma_instr.wave_size; } + + __device__ static auto GetBlkIdx() + { + const auto laneId = GetLaneId(); + const auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform( + make_tuple(1, mfma_instr.num_input_blks, mfma_instr.num_threads_per_blk))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto blk_idx = + threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId)); + + const auto blk_id = blk_idx[I1]; + const auto blk_td = blk_idx[I2]; + + return make_tuple(blk_id, blk_td); + } + + __host__ __device__ static auto CalculateAThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + + __host__ __device__ static auto CalculateBThreadOriginDataIndex() + { + const auto laneId = GetLaneId(); + const auto blk_idx = GetBlkIdx(); + + const auto blk_id = blk_idx[I0]; + const auto blk_td = blk_idx[I1]; + + if constexpr(mfma_instr.is_k_reduction) + { + return make_tuple(blk_id, blk_td); + } + else + { + return make_tuple(0, laneId); + } + } + __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i) { - const index_t laneId = get_thread_local_1d_id() % mfma_type.wave_size; - const index_t blk_id = laneId / mfma_type.num_threads_blk; - const index_t blk_td = laneId % mfma_type.num_threads_blk; + const auto blk_idx = GetBlkIdx(); - index_t n_offset = blk_i * mfma_type.n + blk_td; - index_t m_offset = xdlops_i * mfma_type.m + blk_id * mfma_type.group_size; + const auto blk_id = blk_idx[I0]; + const 
auto blk_td = blk_idx[I1]; + + index_t n_offset = blk_i * mfma_instr.n_per_blk + blk_td; + index_t m_offset = xdlops_i * mfma_instr.m_per_blk + blk_id * mfma_instr.group_size; return CIndex{m_offset, n_offset}; } - static constexpr index_t MRepeats = GetXdlopsInfo().MRepeats; - static constexpr index_t NRepeats = GetXdlopsInfo().NRepeats; - static constexpr index_t MPerXdlops = GetXdlopsInfo().MPerXdlops; - static constexpr index_t NPerXdlops = GetXdlopsInfo().NPerXdlops; + static constexpr auto mfma = MfmaSelector{}; - static constexpr bool IsKReduction = GetXdlopsInfo().IsKReduction(); - static constexpr bool IsABroadcast = GetXdlopsInfo().IsABroadcast(); - static constexpr index_t KPerXdlops = GetXdlopsInfo().GetKPerXdlops(); + static constexpr auto mfma_instr = mfma.selected_mfma; - static constexpr auto GetBlkId(const index_t lane_id) + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto KPerThread = mfma.GetKPerThread(); + + __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() { - return lane_id / mfma_type.num_threads_blk; + return make_tuple( + Number{}, I1, Number{}, I1); } - - static constexpr auto GetBlkTd(const index_t lane_id) - { - return lane_id % mfma_type.num_threads_blk; - } - - static constexpr auto mfma_type = GetXdlopsInfo().mfma_type; - - struct CLayout - { - __host__ __device__ static constexpr index_t M1() { return mfma_type.num_groups_blk; } - __host__ __device__ static constexpr index_t M0() { return mfma_type.group_size; } - __host__ __device__ static constexpr index_t N1() { return mfma_type.num_input_blks; } - __host__ __device__ static constexpr index_t N0() { return mfma_type.num_threads_blk; } - - __device__ static constexpr index_t GetBlkSize() { return mfma_type.num_regs_blk; } - - __device__ static constexpr index_t GetNumBlks() { return mfma_type.num_output_blks; } - - __device__ static constexpr index_t GetNumXdlops() - { - return MPerXdlops * NPerXdlops / - (mfma_type.m * 
mfma_type.n * mfma_type.num_output_blks); - } - }; - - __host__ __device__ static constexpr auto GetCLayout() { return CLayout{}; } }; } // namespace ck diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp index 695ffeeb36..34de1ac0ed 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw.hpp @@ -48,10 +48,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( const auto out_n_k_ho_wo_desc = make_naive_tensor_descriptor_packed(out_n_k_ho_wo_lengths); #if 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 + // [M, N, K0, K1] = [128, 128, 4, 8] for fp16 constexpr index_t BlockSize = 256; - constexpr index_t GemmMPerBlock = 256; + constexpr index_t GemmMPerBlock = 128; constexpr index_t GemmNPerBlock = 128; constexpr index_t GemmKPerBlock = 4; @@ -59,10 +59,10 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( constexpr index_t GemmNPerWave = 32; constexpr index_t GemmK1 = 8; - constexpr index_t MRepeat = 4; + constexpr index_t MRepeat = 2; constexpr index_t NRepeat = 2; - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; + using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 2, 8>; using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; @@ -106,22 +106,22 @@ void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nchw_kcyx_nkhw( Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 
0>{}, - Sequence<0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0>{}), + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, + Sequence<0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0>{})); constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{}; diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp deleted file mode 100644 index 141a326574..0000000000 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk.hpp +++ /dev/null @@ -1,229 +0,0 @@ -#include -#include "device.hpp" -#include "host_tensor.hpp" -#include "transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk.hpp" -#include "driver_gemm_xdlops_v2r2.hpp" - -template -void device_convolution_forward_implicit_gemm_v4r4r2_xdlops_nhwc_kyxc_nhwk( - const InLengths& 
in_n_hi_wi_c_lengths, - const WeiLengths& wei_k_y_x_c_lengths, - const OutLengths& out_n_ho_wo_k_lengths, - const ConvStrides& conv_strides, - const ConvDilations& conv_dilations, - const InLeftPads& in_left_pads, - const InRightPads& in_right_pads, - const Tensor& in_n_hi_wi_c, - const Tensor& wei_k_y_x_c, - Tensor& out_n_ho_wo_k, - ck::index_t nrepeat) -{ - using namespace ck; - - std::cout << __func__ << std::endl; - - constexpr auto I0 = Number<0>{}; - constexpr auto I1 = Number<1>{}; - constexpr auto I2 = Number<2>{}; - constexpr auto I3 = Number<3>{}; - - DeviceMem in_n_hi_wi_c_device_buf(sizeof(TInWei) * in_n_hi_wi_c.mDesc.GetElementSpace()); - DeviceMem wei_k_y_x_c_device_buf(sizeof(TInWei) * wei_k_y_x_c.mDesc.GetElementSpace()); - DeviceMem out_n_ho_wo_k_device_buf(sizeof(TOut) * out_n_ho_wo_k.mDesc.GetElementSpace()); - - in_n_hi_wi_c_device_buf.ToDevice(in_n_hi_wi_c.mData.data()); - wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data()); - out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data()); - - const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor_packed(in_n_hi_wi_c_lengths); - const auto wei_k_y_x_c_desc = make_naive_tensor_descriptor_packed(wei_k_y_x_c_lengths); - const auto out_n_ho_wo_k_desc = make_naive_tensor_descriptor_packed(out_n_ho_wo_k_lengths); - -#if 1 - // [M, N, K0, K1] = [256, 128, 4, 4] for fp32 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 4; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 4>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t 
GemmABlockTransferDstScalarPerVector_GemmK1 = 4; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 4>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 4; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 4; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#elif 1 - // [M, N, K0, K1] = [256, 128, 4, 8] for fp16 - constexpr index_t BlockSize = 256; - - constexpr index_t GemmMPerBlock = 256; - constexpr index_t GemmNPerBlock = 128; - constexpr index_t GemmKPerBlock = 4; - - constexpr index_t GemmMPerWave = 64; - constexpr index_t GemmNPerWave = 64; - constexpr index_t GemmK1 = 8; - - constexpr index_t MRepeat = 2; - constexpr index_t NRepeat = 1; - - using GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1 = Sequence<1, 4, 8>; - using GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmABlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmABlockTransferDstScalarPerVector_GemmK1 = 8; - - using GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1 = Sequence<1, 2, 8>; - using GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1 = Sequence<4, 64, 1>; - - constexpr index_t GemmBBlockTransferSrcScalarPerVector_GemmK1 = 8; - constexpr index_t GemmBBlockTransferDstScalarPerVector_GemmK1 = 8; - - constexpr index_t GemmCThreadTransferDstScalarPerVector = 4; -#endif - - const auto descs = - transform_forward_convolution_into_gemm_v4r4r2_nhwc_kyxc_nhwk_pad(wei_k_y_x_c_desc, - in_n_hi_wi_c_desc, - out_n_ho_wo_k_desc, - conv_strides, - conv_dilations, - in_left_pads, - in_right_pads, - Number{}); - - const auto wei_gemmk0_gemmm_gemmk1_grid_desc = descs[I0]; - const auto in_gemmk0_gemmn_gemmk1_grid_desc = descs[I1]; - const auto out_gemmm_gemmn_grid_desc = descs[I2]; - - // HACK: hacks that control index calculation when 
iterating over A, B, C matrix - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_step_hacks = make_tuple( - make_tuple(Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}), - make_tuple( - Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{}, Sequence<0, 0, 0, 0, 0>{})); - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0>{}, - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0>{})); - - constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 1, 0, 0>{}), - make_tuple(Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 0, 0, 0>{}, - Sequence<0, 0, 2, 0, 0>{})); - - constexpr auto wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0>{}; - - constexpr auto in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks = - Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; - - for(index_t i = 0; i < 5; ++i) - { - float ave_time = driver_gemm_xdlops_v2r2< - BlockSize, - TInWei, - TAcc, - TOut, - InMemoryDataOperationEnum_t::Set, - decltype(wei_gemmk0_gemmm_gemmk1_grid_desc), - decltype(in_gemmk0_gemmn_gemmk1_grid_desc), - decltype(out_gemmm_gemmn_grid_desc), - GemmMPerBlock, - GemmNPerBlock, - GemmKPerBlock, - GemmMPerWave, - GemmNPerWave, - MRepeat, - NRepeat, - GemmABlockTransferThreadSliceLengths_GemmK0_GemmM_GemmK1, - GemmABlockTransferThreadClusterLengths_GemmK0_GemmM_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmABlockTransferSrcScalarPerVector_GemmK1, - GemmABlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - 
GemmBBlockTransferThreadSliceLengths_GemmK0_GemmN_GemmK1, - GemmBBlockTransferThreadClusterLengths_GemmK0_GemmN_GemmK1, - Sequence<1, 0, 2>, - Sequence<1, 0, 2>, - 2, - GemmBBlockTransferSrcScalarPerVector_GemmK1, - GemmBBlockTransferDstScalarPerVector_GemmK1, - false, // don't move back src coordinate after threadwise copy - Sequence<2, 3, 0, 1>, - 2, - GemmCThreadTransferDstScalarPerVector, - decltype(wei_gemmk0_gemmm_gemmk1_grid_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_step_hacks), - decltype(out_m0_m1_m2_n_grid_step_hacks), - decltype(wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks), - decltype(in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks)>( - static_cast(wei_k_y_x_c_device_buf.GetDeviceBuffer()), - static_cast(in_n_hi_wi_c_device_buf.GetDeviceBuffer()), - static_cast(out_n_ho_wo_k_device_buf.GetDeviceBuffer()), - wei_gemmk0_gemmm_gemmk1_grid_desc, - in_gemmk0_gemmn_gemmk1_grid_desc, - out_gemmm_gemmn_grid_desc, - wei_gemmk0_gemmm_gemmk1_grid_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_step_hacks, - out_m0_m1_m2_n_grid_step_hacks, - wei_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks, - in_gemmk0_gemmn_gemmk1_grid_move_slice_window_step_hacks, - nrepeat); - - { - const auto N = out_n_ho_wo_k_lengths[I0]; - const auto K = out_n_ho_wo_k_lengths[I3]; - const auto C = wei_k_y_x_c_lengths[I3]; - - const auto Ho = out_n_ho_wo_k_lengths[I1]; - const auto Wo = out_n_ho_wo_k_lengths[I2]; - - const auto Y = wei_k_y_x_c_lengths[I1]; - const auto X = wei_k_y_x_c_lengths[I2]; - - float perf = (float)(std::size_t(2) * N * K * Ho * Wo * C * Y * X) / - (std::size_t(1000) * 1000 * 1000) / ave_time; - - std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" - << std::endl; - } - } - - // copy result back to host - out_n_ho_wo_k_device_buf.FromDevice(out_n_ho_wo_k.mData.data()); -} diff --git a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp 
b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp index 7067291c8a..a392df6aa8 100644 --- a/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp +++ b/host/driver_offline/include/device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk.hpp @@ -250,22 +250,22 @@ void device_convolution_forward_implicit_gemm_v4r4r4_xdlops_nhwc_kyxc_nhwk( Sequence<0, 0, 0, 0, 0>{})); // 2-: GemmK1 constexpr auto out_m0_m1_m2_n_grid_step_hacks = - make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1+: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2+: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3+: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4+: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5+: M1 - Sequence<0, 0, 0, 0, 0>{}, // 6+: M2 - Sequence<0, 0, 0, 0, 0>{}), // 7+: N1 - make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: MRepeat - Sequence<0, 0, 0, 0, 0>{}, // 1-: NRepeat - Sequence<0, 0, 0, 0, 0>{}, // 2-: MWaves - Sequence<0, 0, 0, 0, 0>{}, // 3-: NWaves - Sequence<0, 0, 0, 0, 0>{}, // 4-: M0 - Sequence<0, 0, 0, 0, 0>{}, // 5-: M1 - Sequence<0, 0, 0, 0, 0>{}, // 6-: M2 - Sequence<0, 0, 0, 0, 0>{})); // 7-: N1 + make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: NWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N1 + make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: MRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: NRepeat + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: MWaves + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: NWaves + 
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M0 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M1 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M2 + Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N1 constexpr auto in_gemmk0_gemmm_gemmk1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0>{}; diff --git a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp index edfce52a19..a33e5bf4d0 100644 --- a/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp +++ b/host/driver_offline/include/driver_gemm_xdlops_v2r3.hpp @@ -129,9 +129,10 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } - const auto c_m0_m1_m2_n_grid_desc = GridwiseGemm::MakeCM0M1M2NGridDescriptor(c_m_n_grid_desc); + const auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc = + GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_m_n_grid_desc); - using CM0M1M2NGridDesc = decltype(c_m0_m1_m2_n_grid_desc); + using CM0N0M1N1M2M3M4N2GridDesc = decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); const auto c_block_cluster_adaptor = GridwiseGemm::MakeCBlockClusterAdaptor(c_m_n_grid_desc); @@ -144,7 +145,7 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, FloatC, remove_reference_t, remove_reference_t, - remove_reference_t, + remove_reference_t, remove_reference_t>; #if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE @@ -158,18 +159,18 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_c_grid, a_k0_m_k1_grid_desc, b_k0_n_k1_grid_desc, - c_m0_m1_m2_n_grid_desc, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, c_block_cluster_adaptor); #elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER DeviceMem a_k0_m_k1_grid_desc_dev_buf(sizeof(AK0MK1GridDesc)); DeviceMem b_k0_n_k1_grid_desc_dev_buf(sizeof(BK0NK1GridDesc)); - DeviceMem c_m0_m1_m2_n_grid_desc_dev_buf(sizeof(CM0M1M2NGridDesc)); + 
DeviceMem c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf(sizeof(CM0N0M1N1M2M3M4N2GridDesc)); DeviceMem c_block_cluster_adaptor_dev_buf(sizeof(CBlockClusterAdaptor)); a_k0_m_k1_grid_desc_dev_buf.ToDevice(&a_k0_m_k1_grid_desc); b_k0_n_k1_grid_desc_dev_buf.ToDevice(&b_k0_n_k1_grid_desc); - c_m0_m1_m2_n_grid_desc_dev_buf.ToDevice(&c_m0_m1_m2_n_grid_desc); + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.ToDevice(&c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc); c_block_cluster_adaptor_dev_buf.ToDevice(&c_block_cluster_adaptor); float ave_time = launch_and_time_kernel( @@ -183,7 +184,8 @@ __host__ float driver_gemm_xdlops_v2r3(const FloatAB* p_a_grid, p_c_grid, cast_pointer_to_constant_address_space(a_k0_m_k1_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(b_k0_n_k1_grid_desc_dev_buf.GetDeviceBuffer()), - cast_pointer_to_constant_address_space(c_m0_m1_m2_n_grid_desc_dev_buf.GetDeviceBuffer()), + cast_pointer_to_constant_address_space( + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc_dev_buf.GetDeviceBuffer()), cast_pointer_to_constant_address_space(c_block_cluster_adaptor_dev_buf.GetDeviceBuffer())); #endif return ave_time; diff --git a/host/driver_offline/src/conv_fwd_driver_offline.cpp b/host/driver_offline/src/conv_fwd_driver_offline.cpp index 32c33003c5..1ab8a822cd 100644 --- a/host/driver_offline/src/conv_fwd_driver_offline.cpp +++ b/host/driver_offline/src/conv_fwd_driver_offline.cpp @@ -24,8 +24,8 @@ #define USE_CONV_FWD_V4R4R2_NHWC 1 #define USE_CONV_FWD_V6R1_NCHW 0 #define USE_CONV_FWD_V5R1_NCHW 0 -#define USE_CONV_FWD_V4R4R2_XDL_NCHW 0 -#define USE_CONV_FWD_V4R4R4_XDL_NHWC 0 +#define USE_CONV_FWD_V4R4R2_XDL_NCHW 1 +#define USE_CONV_FWD_V4R4R4_XDL_NHWC 1 enum ConvForwardAlgo {