diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 81ca2853c6..ca262f0e2a 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -36,6 +36,8 @@ struct ExecutionConfig final int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values) bool time_kernel = false; // (0=no, 1=yes) int verbosity = 0; // (0=no info, 1=verbose info) + int warm_up = 10; + int repeat = 10; }; struct ProblemSizeSplitK final @@ -86,6 +88,8 @@ bool parse_cmd_args(int argc, if(argc >= 12) { problem_size.KBatch = std::stoi(argv[11]); + config.warm_up = std::stoi(argv[12]); + config.repeat = std::stoi(argv[13]); } } else @@ -411,8 +415,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Computing GEMM on device..." << std::endl << std::endl; } - float ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + float ave_time = invoker.Run( + argument, + StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat}); bool res_verified = true; if(config.do_verification > 0) @@ -484,14 +489,15 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c // partial sums(K/ScaleBlockSize)] // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; - std::size_t num_btype = sizeof(ADataType) * M * K / ck::pack_size_v + - sizeof(BDataType) * K * N / ck::pack_size_v + + std::size_t num_btype = sizeof(ADataType) * M * K / ck::pack_size_v + // + sizeof(BDataType) * K * N / ck::pack_size_v + // sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; + sizeof(XDataType) * M * K / ScaleBlockSize + + sizeof(XDataType) * N * K / ScaleBlockSize; float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = static_cast(num_btype) / 1e6f / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << device_op.GetTypeString() << std::endl; diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index c024133adf..6854eaafab 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once @@ -64,15 +64,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; static constexpr auto block_slice_lengths = BlockSliceLengths{}; static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + static constexpr auto wave_thread_cluster_lengths = + Sequence{}; + static constexpr auto wave_cluster_lengths = + Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{}; static constexpr auto thread_single_load_size = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); // After a load, each thread moves by `thread_steps` instead of loading the next elements. // It makes the whole wavefront load contiguous memory, what is required for direct loads. - static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; static __device__ constexpr bool AreThreadClusterLengthsValid() @@ -171,10 +180,16 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() / 64)); + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); - SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin); + // We don't need threadwise offset for lds since it was calculate by HW + // We still need input the wavewise offset. + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -222,7 +237,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad // Loop over the destination block and copy data. static_ford{}([&](auto ordered_dst_access_idx) { const auto src_offset = src_coord_.GetOffset(); - const auto dst_offset = dst_coord_.GetOffset(); + const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset()); // Check if src data is not in the logic padding area. const bool is_src_valid = @@ -312,6 +327,8 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad private: static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + static constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 975d236aa4..89cafe2fca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -1895,8 +1895,9 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto b_scale_grid_buf = make_dynamic_buffer( p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; + static_assert( + is_same_v && + is_same_v); const CElementwiseOperation c_element_op{}; // divide block work by [M, N]