diff --git a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp index a8236737df..36c6783204 100644 --- a/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp +++ b/composable_kernel/include/tensor_operation/blockwise_gemm_xdlops.hpp @@ -10,6 +10,7 @@ namespace ck { template {}; static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL); static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL); + StaticBufferV2, MRepeat * NRepeat, true> + c_thread_buf_; + + __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } + __device__ static auto GetWaveIdx() { const index_t thread_id = get_thread_local_1d_id(); @@ -162,7 +167,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { return transform_tensor_descriptor( AK0MK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), + make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -174,7 +179,7 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 { return transform_tensor_descriptor( BK0NK1BlockDesc{}, - make_tuple(make_pass_through_transform(Number{}), + make_tuple(make_pass_through_transform(Number{}), make_unmerge_transform( make_tuple(Number{}, Number{}, Number{})), make_pass_through_transform(Number{})), @@ -195,48 +200,43 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 auto b_thread_buf = make_static_buffer( b_thread_desc_.GetElementSpaceSize()); - vector_type a_thread_vec; - - vector_type b_thread_vec; - - static_for<0, KPerBlock, xdlops_gemm.KPerXdlops / xdlops_gemm.KPerThread>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { // read A a_thread_copy_.Run(a_k0_m0_m1_m2_k1_block_desc, - make_tuple(k0, I0, I0, I0, I0), + make_tuple(I0, m0, I0, I0, I0), a_block_buf, a_thread_desc_, make_tuple(I0, I0, I0, I0, I0), a_thread_buf); - // read B - b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, - make_tuple(k0, I0, I0, I0, I0), - b_block_buf, - b_thread_desc_, - make_tuple(I0, I0, I0, I0, I0), - b_thread_buf); + static_for<0, NRepeat, 1>{}([&](auto n0) { + // read B + b_thread_copy_.Run(b_k0_n0_n1_n2_k1_block_desc, + make_tuple(I0, n0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, I0, I0, I0, I0), + b_thread_buf); - using mfma_input_type = typename vector_type::type; + static_for<0, K0, xdlops_gemm.K0PerXdlops>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; - static_for<0, MRepeat, 1>{}([&](auto m0) { - static_for<0, NRepeat, 1>{}([&](auto n0) { static_for<0, K1, 1>{}([&](auto i) { a_thread_vec.template AsType()(i) = a_thread_buf - [Number{}]; - }); - - static_for<0, K1, 1>{}([&](auto i) { + [Number{}]; b_thread_vec.template AsType()(i) = b_thread_buf - [Number{}]; + [Number{}]; }); - constexpr index_t c_offset = - c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0)); + using mfma_input_type = + typename vector_type::type; - xdlops_gemm.template Run( - a_thread_vec.template AsType(), - b_thread_vec.template AsType(), - c_thread_buf); + constexpr index_t c_offset = c_thread_desc_.CalculateOffset(make_tuple(m0, n0)); + + xdlops_gemm.template Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_thread_buf.GetVector(Number{})); }); }); }); @@ -244,35 +244,35 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1 private: // A[K, M] - static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, I1, I1, Number{})); + static constexpr auto a_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); // B[K, N] - static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(I1, Number{}, I1, I1, Number{})); + static constexpr auto b_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, I1, I1, I1, Number{})); - static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{}, Number{})); + static constexpr auto c_thread_desc_ = + make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); using AThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence, Sequence<0, 1, 2, 3, 4>, 4, K1, - 1>; + K1>; using BThreadCopy = ThreadwiseTensorSliceTransfer_v4, + Sequence, Sequence<0, 1, 2, 3, 4>, 4, K1, - 1>; + K1>; AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()}; BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()}; diff --git a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp index 3e4d74e9d8..c6f491dc47 100644 --- a/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp +++ b/composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r3.hpp @@ -142,6 +142,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 static constexpr auto I4 = Number<4>{}; static constexpr auto I5 = Number<5>{}; static constexpr auto I6 = Number<6>{}; + static constexpr auto I7 = Number<7>{}; // K1 should be Number<...> static constexpr auto K1 = Number{}; @@ -220,6 +221,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 using BlockwiseGemm = BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1{}; - constexpr auto c_mr_nr_blk_desc = - make_naive_tensor_descriptor_packed(make_tuple(Number{}, Number{})); - - constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = - blockwise_gemm.GetCM0N0M1N1M2M3M4N2ThreadDescriptor(); - constexpr auto CBlkSize = c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc.GetElementSpaceSize(); - - StaticBuffer, - c_mr_nr_blk_desc.GetElementSpaceSize(), - true> - c_thread_buf; + auto c_thread_buf = blockwise_gemm.GetCThreadBuffer(); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size = @@ -460,9 +452,18 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc = blockwise_gemm.GetCM0N0M1N1M2M3M4N2BlockDescriptor(); + constexpr auto M0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I0); + constexpr auto N0 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I1); + constexpr auto M1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I2); + constexpr auto N1 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I3); constexpr auto M2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I4); constexpr auto M3 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I5); constexpr auto M4 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I6); + constexpr auto N2 = c_m0_n0_m1_n1_m2_m3_m4_n2_block_desc.GetLength(I7); + + constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc = + make_naive_tensor_descriptor_packed(make_tuple( + Number{}, Number{}, I1, I1, Number{}, I1, Number{}, I1)); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -477,224 +478,54 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks = CGridStepHacks{}; + const auto m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor = + make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), + make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(Sequence<0>{})); + + const auto m_thread_data_on_grid_idx = + m_thread_data_on_grid_to_m0_m1_m2_m3_m4_adaptor.CalculateBottomIndex( + make_multi_index(m_thread_data_on_grid)); + + const auto n_thread_data_on_grid_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( + make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), + make_tuple(Sequence<0, 1, 2>{}), + make_tuple(Sequence<0>{})); + + const auto n_thread_data_on_grid_idx = + n_thread_data_on_grid_to_n0_n1_n2_adaptor.CalculateBottomIndex( + make_multi_index(n_thread_data_on_grid)); + auto c_thread_copy = - ThreadwiseTensorSliceTransfer_v1r3, + Sequence, CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstVectorDim, CThreadTransferDstScalarPerVector, CGlobalMemoryDataOperation, 1, true>{ + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - make_multi_index(0, - 0, - 0, - 0, - m_thread_data_on_grid / (M3 * M4), - m_thread_data_on_grid % (M3 * M4) / M4, - m_thread_data_on_grid % M4, - n_thread_data_on_grid)}; + make_multi_index(m_thread_data_on_grid_idx[I0], + n_thread_data_on_grid_idx[I0], + m_thread_data_on_grid_idx[I1], + n_thread_data_on_grid_idx[I1], + m_thread_data_on_grid_idx[I2], + m_thread_data_on_grid_idx[I3], + m_thread_data_on_grid_idx[I4], + n_thread_data_on_grid_idx[I2])}; - auto init_copy = [&](auto c_thread_idx_) { - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - - return c_thread_idx_; - }; - - auto mrepeat_plus_copy = [&](auto c_thread_idx_) { - constexpr auto mrepeat_step_plus = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - mrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - auto nrepeat_plus_copy = [&](auto c_thread_idx_) { - constexpr auto nrepeat_step_plus = make_multi_index(0, 1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - nrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - auto mrepeat_minus_copy = [&](auto c_thread_idx_) { - constexpr auto mrepeat_step_plus = make_multi_index(-1, 0, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - mrepeat_step_plus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - auto nrepeat_minus_copy = [&](auto c_thread_idx_) { - constexpr auto nrepeat_step_minus = make_multi_index(0, -1, 0, 0, 0, 0, 0, 0); - c_thread_copy.MoveDstSliceWindow(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - nrepeat_step_minus); - - constexpr auto blk_off = c_mr_nr_blk_desc.CalculateOffset(c_thread_idx_); - c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, - make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), - c_thread_buf[Number{}].template AsType(), - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, - c_grid_buf, - c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); - }; - - static_assert((MRepeat == 4 && NRepeat == 4) or (MRepeat == 4 && NRepeat == 2) or - (MRepeat == 2 && NRepeat == 4) or (MRepeat == 2 && NRepeat == 2) or - (MRepeat == 2 && NRepeat == 1) or (MRepeat == 1 && NRepeat == 2) or - (MRepeat == 1 && NRepeat == 1), - "wrong"); - - if constexpr(MRepeat == 4 && NRepeat == 4) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - nrepeat_plus_copy(make_tuple(I0, I3)); - mrepeat_plus_copy(make_tuple(I1, I3)); - nrepeat_minus_copy(make_tuple(I1, I2)); - nrepeat_minus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - nrepeat_plus_copy(make_tuple(I2, I1)); - nrepeat_plus_copy(make_tuple(I2, I2)); - nrepeat_plus_copy(make_tuple(I2, I3)); - mrepeat_plus_copy(make_tuple(I3, I3)); - nrepeat_minus_copy(make_tuple(I3, I2)); - nrepeat_minus_copy(make_tuple(I3, I1)); - nrepeat_minus_copy(make_tuple(I3, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - mrepeat_plus_copy(make_tuple(I3, I0)); - nrepeat_plus_copy(make_tuple(I3, I1)); - mrepeat_minus_copy(make_tuple(I2, I1)); - mrepeat_minus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - mrepeat_plus_copy(make_tuple(I1, I2)); - mrepeat_plus_copy(make_tuple(I2, I2)); - mrepeat_plus_copy(make_tuple(I3, I2)); - nrepeat_plus_copy(make_tuple(I3, I3)); - mrepeat_minus_copy(make_tuple(I2, I3)); - mrepeat_minus_copy(make_tuple(I1, I3)); - mrepeat_minus_copy(make_tuple(I0, I3)); - } - } - else if constexpr(MRepeat == 4 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - mrepeat_plus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - nrepeat_plus_copy(make_tuple(I2, I1)); - mrepeat_plus_copy(make_tuple(I3, I1)); - nrepeat_minus_copy(make_tuple(I3, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - mrepeat_plus_copy(make_tuple(I2, I0)); - mrepeat_plus_copy(make_tuple(I3, I0)); - nrepeat_plus_copy(make_tuple(I3, I1)); - mrepeat_minus_copy(make_tuple(I2, I1)); - mrepeat_minus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 4) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - nrepeat_plus_copy(make_tuple(I0, I3)); - mrepeat_plus_copy(make_tuple(I1, I3)); - nrepeat_minus_copy(make_tuple(I1, I2)); - nrepeat_minus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - nrepeat_plus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - nrepeat_plus_copy(make_tuple(I0, I2)); - mrepeat_plus_copy(make_tuple(I1, I2)); - nrepeat_plus_copy(make_tuple(I1, I3)); - mrepeat_minus_copy(make_tuple(I0, I3)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - - if constexpr(CAccessOrderMRepeatNRepeat) - { - nrepeat_plus_copy(make_tuple(I0, I1)); - mrepeat_plus_copy(make_tuple(I1, I1)); - nrepeat_minus_copy(make_tuple(I1, I0)); - } - else - { - mrepeat_plus_copy(make_tuple(I1, I0)); - nrepeat_plus_copy(make_tuple(I1, I1)); - mrepeat_minus_copy(make_tuple(I0, I1)); - } - } - else if constexpr(MRepeat == 2 && NRepeat == 1) - { - init_copy(make_tuple(I0, I0)); - mrepeat_plus_copy(make_tuple(I1, I0)); - } - else if constexpr(MRepeat == 1 && NRepeat == 2) - { - init_copy(make_tuple(I0, I0)); - nrepeat_plus_copy(make_tuple(I0, I1)); - } - else if constexpr(MRepeat == 1 && NRepeat == 1) - { - init_copy(make_tuple(I0, I0)); - } + c_thread_copy.Run(c_m0_n0_m1_n1_m2_m3_m4_n2_thread_desc, + make_tuple(I0, I0, I0, I0, I0, I0, I0, I0), + c_thread_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_desc, + c_grid_buf, + c_m0_n0_m1_n1_m2_m3_m4_n2_grid_tensor_step_hacks); } } }; // namespace ck diff --git a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp index f945b0fdf5..10633f8f32 100644 --- a/composable_kernel/include/tensor_operation/xdlops_gemm.hpp +++ b/composable_kernel/include/tensor_operation/xdlops_gemm.hpp @@ -44,15 +44,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); + intrin_mfma_f32_32x32x1f32::Run(a, b, reg_c); } }; @@ -71,15 +66,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); + intrin_mfma_f32_32x32x2f32::Run(a, b, reg_c); } }; @@ -98,15 +88,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); + intrin_mfma_f32_16x16x4f32::Run(a, b, reg_c); } }; @@ -125,15 +110,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); + intrin_mfma_f32_16x16x1f32::Run(a, b, reg_c); } }; @@ -153,15 +133,10 @@ struct mfma_type static constexpr index_t k_per_blk = 1; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); + intrin_mfma_f32_4x4x1f32::Run(a, b, reg_c); } }; @@ -180,15 +155,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); + intrin_mfma_f32_32x32x4f16::Run(a, b, reg_c); } }; @@ -207,15 +177,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); + intrin_mfma_f32_32x32x8f16::Run(a, b, reg_c); } }; @@ -234,15 +199,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = true; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); + intrin_mfma_f32_16x16x16f16::Run(a, b, reg_c); } }; @@ -261,15 +221,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); + intrin_mfma_f32_16x16x4f16::Run(a, b, reg_c); } }; @@ -288,15 +243,10 @@ struct mfma_type static constexpr index_t k_per_blk = 4; static constexpr bool is_k_reduction = false; - template + template __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const { - intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); + intrin_mfma_f32_4x4x4f16::Run(a, b, reg_c); } }; @@ -732,7 +682,7 @@ struct XdlopsGemm return MPerXdlops * NPerXdlops / mfma_instr.wave_size; } - template + template __device__ void Run(const FloatA& p_a_wave, const FloatB& p_b_wave, FloatC& p_c_thread) const { static_assert(is_same::value || is_same::value || @@ -740,8 +690,7 @@ struct XdlopsGemm "base base_type must be float, half, ushort!"); static_for<0, KPack / mfma_instr.k_per_blk, 1>{}([&](auto k) { - mfma_instr.template run( - p_a_wave[k], p_b_wave[k], p_c_thread); + mfma_instr.template run(p_a_wave[k], p_b_wave[k], p_c_thread); }); } @@ -819,8 +768,9 @@ struct XdlopsGemm static constexpr auto mfma_instr = mfma.selected_mfma; - static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto KPerThread = mfma.GetKPerThread(); + static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); + static constexpr auto K1PerXdlops = mfma.GetKPerThread(); + static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths() { diff --git a/composable_kernel/include/utility/amd_xdlops.hpp b/composable_kernel/include/utility/amd_xdlops.hpp index da74fe1d48..083e47fbf1 100644 --- a/composable_kernel/include/utility/amd_xdlops.hpp +++ b/composable_kernel/include/utility/amd_xdlops.hpp @@ -51,304 +51,196 @@ extern "C" __device__ float16_t llvm_intrin_amdgcn_mfma_f32_16x16x2bf16( extern "C" __device__ float4_t llvm_intrin_amdgcn_mfma_f32_4x4x2bf16( ushort2_t, ushort2_t, float4_t, int, int, int) __asm("llvm.amdgcn.mfma.f32.4x4x2bf16"); -template +template struct intrin_mfma_f32_32x32x1f32; -template -struct intrin_mfma_f32_32x32x1f32<64, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x1f32<64, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; -template -struct intrin_mfma_f32_32x32x1f32<32, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x1f32<32, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; -template +template struct intrin_mfma_f32_32x32x2f32; -template -struct intrin_mfma_f32_32x32x2f32<32, 32, COffset> +template <> +struct intrin_mfma_f32_32x32x2f32<32, 32> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x2f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x2f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x4f32; -template -struct intrin_mfma_f32_16x16x4f32<16, 16, COffset> +template <> +struct intrin_mfma_f32_16x16x4f32<16, 16> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x4f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x1f32; -template -struct intrin_mfma_f32_16x16x1f32<16, 64, COffset> +template <> +struct intrin_mfma_f32_16x16x1f32<16, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 2, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; -template +template struct intrin_mfma_f32_4x4x1f32; -template -struct intrin_mfma_f32_4x4x1f32<4, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x1f32<4, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; -template -struct intrin_mfma_f32_4x4x1f32<8, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x1f32<8, 64> { template __device__ static void Run(const float& reg_a, const float& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x1f32( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x1f32( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; -template +template struct intrin_mfma_f32_32x32x4f16; -template -struct intrin_mfma_f32_32x32x4f16<64, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x4f16<64, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 1, 1, 0); } }; -template -struct intrin_mfma_f32_32x32x4f16<32, 64, COffset> +template <> +struct intrin_mfma_f32_32x32x4f16<32, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 1, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 1, 0, 0); } }; -template +template struct intrin_mfma_f32_32x32x8f16; -template -struct intrin_mfma_f32_32x32x8f16<32, 32, COffset> +template <> +struct intrin_mfma_f32_32x32x8f16<32, 32> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_32x32x8f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_32x32x8f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x16f16; -template -struct intrin_mfma_f32_16x16x16f16<16, 16, COffset> +template <> +struct intrin_mfma_f32_16x16x16f16<16, 16> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x16f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 0, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x16f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 0, 0, 0); } }; -template +template struct intrin_mfma_f32_16x16x4f16; -template -struct intrin_mfma_f32_16x16x4f16<16, 64, COffset> +template <> +struct intrin_mfma_f32_16x16x4f16<16, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_16x16x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 2, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_16x16x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 2, 0, 0); } }; -template +template struct intrin_mfma_f32_4x4x4f16; -template -struct intrin_mfma_f32_4x4x4f16<4, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x4f16<4, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); } }; -template -struct intrin_mfma_f32_4x4x4f16<8, 64, COffset> +template <> +struct intrin_mfma_f32_4x4x4f16<8, 64> { template __device__ static void Run(const half4_t& reg_a, const half4_t& reg_b, FloatC& reg_c) { - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 0, - 0); - reg_c(Number{}).template AsType()(Number<0>{}) = - llvm_intrin_amdgcn_mfma_f32_4x4x4f16( - reg_a, - reg_b, - reg_c[Number{}].template AsType()[Number<0>{}], - 4, - 1, - 0); + reg_c.template AsType()(Number<0>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<0>{}], 4, 0, 0); + reg_c.template AsType()(Number<1>{}) = llvm_intrin_amdgcn_mfma_f32_4x4x4f16( + reg_a, reg_b, reg_c.template AsType()[Number<1>{}], 4, 1, 0); } }; @@ -448,7 +340,6 @@ template __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16(const ushort2_t* reg_a, const ushort2_t* reg_b, c_vec16_1_t::VecType reg_c); - template <> __device__ c_vec16_1_t::VecType intrin_mfma_f32_16x16x2bf16<16, 64>(const ushort2_t* reg_a, const ushort2_t* reg_b, diff --git a/composable_kernel/include/utility/static_buffer.hpp b/composable_kernel/include/utility/static_buffer.hpp index cd67b8a0be..9615d10c59 100644 --- a/composable_kernel/include/utility/static_buffer.hpp +++ b/composable_kernel/include/utility/static_buffer.hpp @@ -55,6 +55,98 @@ struct StaticBuffer : public StaticallyIndexedArray __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } }; +template +struct StaticBufferV2 : public StaticallyIndexedArray +{ + using type = T; + using base = StaticallyIndexedArray; + + using VecBaseType = typename T::d1_t; + + __host__ __device__ static constexpr index_t GetVectorSize() + { + return sizeof(typename T::type) / sizeof(VecBaseType); + } + + static constexpr index_t vector_size = GetVectorSize(); + + VecBaseType invalid_element_value_ = VecBaseType{0}; + + T invalid_vec_value_ = T{0}; + + __host__ __device__ constexpr StaticBufferV2() : base{} {} + + __host__ __device__ constexpr StaticBufferV2(VecBaseType invalid_element_value) + : base{}, + invalid_vec_value_{invalid_element_value}, + invalid_element_value_{invalid_element_value} + { + } + + __host__ __device__ static constexpr AddressSpaceEnum_t GetAddressSpace() + { + return BufferAddressSpace; + } + + template + __host__ __device__ constexpr auto& GetVector(Number vec_id) + { + return this->At(vec_id); + } + + template + __host__ __device__ constexpr const auto& GetVector(Number vec_id) const + { + return this->At(vec_id); + } + + template + __host__ __device__ constexpr auto& GetElement(Number i, bool) + { + constexpr auto vec_id = Number{}; + constexpr auto vec_off = Number{}; + + return this->At(vec_id).template AsType()(vec_off); + } + + template + __host__ __device__ constexpr auto GetElement(Number i, bool is_valid_element) const + { + constexpr auto vec_id = Number{}; + constexpr auto vec_off = Number{}; + + if constexpr(InvalidElementUseNumericalZeroValue) + { + return is_valid_element ? this->At(vec_id).template AsType()[vec_off] + : VecBaseType{0}; + } + else + { + return is_valid_element ? this->At(vec_id).template AsType()[vec_off] + : invalid_element_value_; + } + } + + template + __host__ __device__ constexpr auto operator[](Number i) const + { + return GetElement(i, true); + } + + template + __host__ __device__ constexpr auto& operator()(Number i) + { + return GetElement(i, true); + } + + __host__ __device__ static constexpr bool IsStaticBuffer() { return true; } + + __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; } +}; + template __host__ __device__ constexpr auto make_static_buffer(Number) {