diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 85509ec3bd..5e15ccd04f 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -36,6 +36,8 @@ struct ExecutionConfig final int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values) bool time_kernel = false; // (0=no, 1=yes) int verbosity = 0; // (0=no info, 1=verbose info) + int warm_up = 10; + int repeat = 10; }; struct ProblemSizeSplitK final @@ -86,6 +88,8 @@ bool parse_cmd_args(int argc, if(argc >= 12) { problem_size.KBatch = std::stoi(argv[11]); + config.warm_up = std::stoi(argv[12]); + config.repeat = std::stoi(argv[13]); } } else @@ -282,22 +286,13 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c // ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); // ck::utils::FillConstant{b_data_element(1.0f)}(b_k_n); - if constexpr(ck::is_same_v) - { - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_2{120, 129}); // scales: {0.25, 0.5, 1, 2} - b_k_n_scale.GenerateTensorValue( - GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - // ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); - // ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); - } - else - { - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(a_m_k_scale); - ck::utils::FillUniformDistributionIntegerValue{-1.0f, 1.0f}(b_k_n_scale); - // ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); - // ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); - } + static_assert(ck::is_same_v); + a_m_k_scale.GenerateTensorValue( + GeneratorTensor_2{120, 129}); // scales: {0.25, 0.5, 1, 2} + b_k_n_scale.GenerateTensorValue( + GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} + // ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); + // ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); break; @@ 
-420,8 +415,9 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Computing GEMM on device..." << std::endl << std::endl; } - float ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); + float ave_time = invoker.Run( + argument, + StreamConfig{nullptr, config.time_kernel, config.verbosity, config.warm_up, config.repeat}); bool res_verified = true; if(config.do_verification > 0) @@ -493,14 +489,14 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c // partial sums(K/ScaleBlockSize)] // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; - std::size_t num_btype = sizeof(ADataType) * M * K / ck::packed_size_v + - sizeof(BDataType) * K * N / ck::packed_size_v + - sizeof(CDataType) * M * N + - sizeof(XDataType) * (M * K + K * N) / ScaleBlockSize; + std::size_t num_btype = + sizeof(ADataType) * M * K / ck::packed_size_v + + sizeof(BDataType) * K * N / ck::packed_size_v + sizeof(CDataType) * M * N + + sizeof(XDataType) * M * K / ScaleBlockSize + sizeof(XDataType) * N * K / ScaleBlockSize; float tflops = static_cast(flop) / 1.E9 / ave_time; - float gb_per_sec = num_btype / 1.E6 / ave_time; + float gb_per_sec = static_cast(num_btype) / 1e6f / ave_time; std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << device_op.GetTypeString() << std::endl; diff --git a/example/67_gemm_microscaling/gemm_mx_fp4.cpp b/example/67_gemm_microscaling/gemm_mx_fp4.cpp index b3b6345871..cff5148fa7 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4.cpp @@ -23,8 +23,8 @@ using AElementOp = PassThrough; // elementwise transformation for A matrix using BElementOp = PassThrough; // elementwise transformation for B matrix using CElementOp = PassThrough; // elementwise 
transformation for C matrix -constexpr ck::index_t DataPackedSize = 2; // Packed representation of data -constexpr ck::index_t ScaleBlockSize = 32; // scaling block size +constexpr ck::index_t DataPackedSize = 2; // Packed representation of data +constexpr ck::index_t ScaleBlockSize = 32; // scaling block size constexpr ck::index_t KPerBlock = 256 / DataPackedSize; // 256 f4 = 128 fp4x2 constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; @@ -50,14 +50,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle GemmSpec, // GemmSpec ScaleBlockSize, // ScaleBlockSize: Scaling block size 256, // BlockSize: Thread block size - 192, // MPerBlock + 256, // MPerBlock 256, // NPerBlock KPerBlock, // KPerBlock 16, // AK1 16, // BK1 16, // MPerXDL 16, // NPerXDL - 6, // MXdlPerWave + 8, // MXdlPerWave 8, // NXdlPerWave S<8, 32, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1 S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder @@ -65,14 +65,14 @@ using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffle 2, // ABlockTransferSrcVectorDim 16, // ABlockTransferSrcScalarPerVector 16, // ABlockTransferDstScalarPerVector_AK1 - false, // ABlockLdsExtraM + true, // ABlockLdsExtraM S<8, 32, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1 S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder S<1, 0, 2>, // BBlockTransferSrcAccessOrder 2, // BBlockTransferSrcVectorDim 16, // BBlockTransferSrcScalarPerVector 16, // BBlockTransferDstScalarPerVector_BK1 - false, // BBlockLdsExtraN + true, // BBlockLdsExtraN 2, // CShuffleMXdlPerWavePerShuffle 2, // CShuffleNXdlPerWavePerShuffle S<1, 32, 1, 8>, // CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp index 14bed63bcb..2b5b5d3305 100644 --- 
a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx.hpp @@ -203,8 +203,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto i) { if constexpr(i < mfma_stages_more) { - static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_a) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } else { - static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_a) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } @@ -274,23 +266,15 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto i) { if constexpr((i + num_buffer_load_inst_a) < mfma_stages_more) { - static_for<0, mfma_perstage_more, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_more, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_a) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } else { - static_for<0, mfma_perstage_less, 1>{}([&](auto imfma) { + static_for<0, mfma_perstage_less, 1>{}([&](auto /*imfma*/) { __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA - if constexpr(imfma < num_dswrite_per_issue_b) - { - __builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write - } }); __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read } @@ -392,14 +376,14 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}> 
b_scale_thread_bufs; // Global prefetch 1 - a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); - b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + a_blockwise_copy.Run(a_grid_desc, a_grid_buf, a_block_desc, a_block_bufs(I0)); + b_blockwise_copy.Run(b_grid_desc, b_grid_buf, b_block_desc, b_block_bufs(I0)); a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); @@ -476,18 +460,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; @@ -502,7 +476,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}, I0, Number{}), - a_block_buf, + a_block_bufs(I0), a_thread_desc_, make_tuple(Number{}, I0, @@ -524,7 +498,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}, I0, Number{}), - b_block_buf, + b_block_bufs(I0), b_thread_desc_, make_tuple(Number{}, I0, @@ -536,6 +510,13 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { @@ -699,7 +680,8 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx 47 96 --> 111| 160 --> 175 224 --> 239| etc. // t48: |48 --> 63 112 --> 127| 176 --> 191 240 --> 255| etc. 
// k = 0 k = 1 - block_sync_lds(); + // __builtin_amdgcn_s_waitcnt(3952); + // block_sync_lds(); static_for<0, KRepeat, 1>{}([&](auto k) { constexpr auto k_step = k * xdlops_gemm.KPerXdlops * KPack / xdlops_gemm.K1PerXdlops; @@ -716,7 +698,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}, I0, Number{}), - a_block_buf, + a_block_bufs(scale_mem_buf), a_thread_desc_, make_tuple(Number{}, I0, @@ -740,7 +722,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}, I0, Number{}), - b_block_buf, + b_block_bufs(scale_mem_buf), b_thread_desc_, make_tuple(Number{}, I0, @@ -798,10 +780,6 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto m0) { static_for<0, NRepeat / NXdlPack, 1>{}([&](auto n0) { static_for<0, KRepeat / KXdlPack, 1>{}([&](auto k0) { @@ -880,6 +858,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}([&](auto k) { @@ -897,7 +876,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}, I0, Number{}), - a_block_buf, + a_block_bufs(I1), a_thread_desc_, make_tuple(Number{}, I0, @@ -920,7 +899,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_mx{}, I0, Number{}), - b_block_buf, + b_block_bufs(I1), b_thread_desc_, make_tuple(Number{}, I0, diff --git a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp index a737c9195b..6854eaafab 100644 --- a/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp +++ b/include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. 
#pragma once @@ -42,13 +42,16 @@ namespace ck { template + index_t ScalarPerVector, + bool SrcXor = true> struct ThreadGroupTensorSliceTransfer_DirectLoad { static constexpr index_t nDim = remove_reference_t::GetNumOfDimension(); @@ -61,15 +64,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad using DstCoordStep = decltype(make_tensor_coordinate_step(DstDesc{}, Index{})); static constexpr auto I0 = Number<0>{}; + static constexpr auto I1 = Number<1>{}; static constexpr auto block_slice_lengths = BlockSliceLengths{}; static constexpr auto thread_cluster_lengths = ThreadClusterLengths{}; + static constexpr auto wave_thread_cluster_lengths = + Sequence{}; + static constexpr auto wave_cluster_lengths = + Sequence<1, ThreadGroup::GetNumOfThread() / 64, 1>{}; static constexpr auto thread_single_load_size = generate_sequence( detail::lambda_scalar_per_access{}, Number{}); // After a load, each thread moves by `thread_steps` instead of loading the next elements. // It makes the whole wavefront load contiguous memory, what is required for direct loads. - static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto thread_steps = thread_cluster_lengths * thread_single_load_size; + static constexpr auto wave_single_load_size = + wave_thread_cluster_lengths * thread_single_load_size; static constexpr auto thread_slice_lengths = block_slice_lengths / thread_steps; static __device__ constexpr bool AreThreadClusterLengthsValid() @@ -96,8 +108,12 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad // VALID: ThreadClusterLengths = [4, 16, 4] or [2, 32, 4] or [1, 64, 4] since in the // first iteration, threads 0-63 write [0, 0, 0] - [0, 15, 7] -> 128 consecutive // elements = 64 consecutive DWORDs. 
+#if defined(__gfx950__) + int num_contiguous_dwords = 4; +#else int num_contiguous_dwords = 1; - bool is_contiguous = true; +#endif + bool is_contiguous = true; static_for<0, nDim, 1>{}([&](auto i) { if(is_contiguous) { @@ -141,11 +157,11 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad "When loading more than one element per thread at once, the contiguous " "dimension must be the same between source and destination."); - constexpr auto dword_bytes = 4; - constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); - static_assert(bytes_per_thread_load == dword_bytes, - "Direct load transfer requires each thread to load exactly a single " - "DWORD of data."); + // constexpr auto dword_bytes = 4; + // constexpr auto bytes_per_thread_load = ScalarPerVector * sizeof(SrcData); + // static_assert(bytes_per_thread_load == dword_bytes, + // "Direct load transfer requires each thread to load exactly a single " + // "DWORD of data."); static_assert(nDim == remove_cvref_t::GetNumOfDimension() && nDim == remove_cvref_t::GetNumOfDimension() && @@ -156,18 +172,24 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad "The number of threads cannot be less than the number of elements in " "thread cluster lengths."); - static_assert( - AreThreadClusterLengthsValid(), - "Thread cluster lengths are incorrect. They must be set in a way that allows a single " - "wavefront to write contiguous DWORDs into LDS memory. "); + // static_assert( + // AreThreadClusterLengthsValid(), + // "Thread cluster lengths are incorrect. They must be set in a way that allows a single + // " "wavefront to write contiguous DWORDs into LDS memory. 
"); const auto thread_cluster_idx = thread_cluster_desc_.CalculateBottomIndex(make_multi_index(ThreadGroup::GetThreadId())); + const auto wave_cluster_idx = wave_cluster_desc_.CalculateBottomIndex( + make_multi_index(ThreadGroup::GetThreadId() / 64)); + const auto thread_data_idx_begin = thread_cluster_idx * thread_single_load_size; + const auto wave_data_idx_begin = wave_cluster_idx * wave_single_load_size; SetSrcSliceOrigin(src_desc, src_block_slice_origin + thread_data_idx_begin); - SetDstSliceOrigin(dst_desc, dst_block_slice_origin + thread_data_idx_begin); + // We don't need threadwise offset for lds since it was calculate by HW + // We still need input the wavewise offset. + SetDstSliceOrigin(dst_desc, dst_block_slice_origin + wave_data_idx_begin); } __device__ void SetSrcSliceOrigin(const SrcDesc& src_desc, const Index& src_slice_origin_idx) @@ -215,7 +237,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad // Loop over the destination block and copy data. static_ford{}([&](auto ordered_dst_access_idx) { const auto src_offset = src_coord_.GetOffset(); - const auto dst_offset = dst_coord_.GetOffset(); + const auto dst_offset = __builtin_amdgcn_readfirstlane(dst_coord_.GetOffset()); // Check if src data is not in the logic padding area. 
const bool is_src_valid = @@ -303,7 +325,10 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad } private: - static constexpr auto thread_cluster_desc_ = make_cluster_descriptor(ThreadClusterLengths{}); + static constexpr auto thread_cluster_desc_ = + make_cluster_descriptor(ThreadClusterLengths{}, ThreadClusterArrangeOrder{}); + static constexpr auto wave_cluster_desc_ = + make_cluster_descriptor(wave_cluster_lengths, ThreadClusterArrangeOrder{}); SrcCoord src_coord_; DstCoord dst_coord_; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index ebed17cfb0..d48e4dca0f 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -299,119 +299,43 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX 1) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - // Tail number could be Odd or Even + return Tuple>{}; else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { -#if 1 - if(arg.KBatch > 1) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } -#endif - } + return Tuple, constant>{}; else { - throw std::runtime_error("wrong! 
BlkGemmPipelineVer"); + static_assert(false, "Unexpected BlkGemmPipelineVer!"); } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) + }(); + const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split); + using BoolChoices = Tuple; + static_for_product>{}( + [&](auto mainloop_choice, auto KBatch_cond_choice, auto tail_num_choice) { + constexpr auto CGlobalMemoryDataOperation = + KBatch_cond_choice.value ? InMemoryDataOperationEnum::AtomicAdd + : InMemoryDataOperationEnum::Set; + if(mainloop_choice.value == has_main_k_block_loop && + KBatch_cond_choice.value == (arg.KBatch > 1) && + tail_num_choice.value == tail_num) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< // + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; Run(kernel); } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); - } - } - } - + }); return ave_time; } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp index 8e9299ff52..89cafe2fca 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp @@ -14,6 +14,7 @@ #include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp" #include "ck/utility/common_header.hpp" #include "ck/utility/env.hpp" +#include "ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp" namespace ck { @@ -76,9 +77,10 @@ __global__ void GridwiseGemm::template Run_2Lds( karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_a_scale_grid + splitk_batch_offset.a_scale_k_split_offset, karg.p_b_grid + splitk_batch_offset.b_k_split_offset, + karg.p_b_scale_grid + 
splitk_batch_offset.b_scale_k_split_offset, karg.p_c_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, p_shared_0, p_shared_1, karg); @@ -198,7 +200,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 NPerXdl, ComputeTypeB, is_single_rate_mfma, - is_scale_mfma>::selected_mfma.k_per_blk/APackedSize); + is_scale_mfma>::selected_mfma.k_per_blk / + APackedSize); using ThisThreadBlock = ThisThreadBlock; @@ -265,10 +268,18 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 __host__ __device__ static constexpr auto MakeGemmMmaTileDescriptor(const TileDesc_K0_MN_K1&) { constexpr index_t K0 = TileDesc_K0_MN_K1{}.GetLength(Number<0>{}); + constexpr index_t MN = TileDesc_K0_MN_K1{}.GetLength(Number<1>{}); constexpr index_t K1 = TileDesc_K0_MN_K1{}.GetLength(Number<2>{}); - return transform_tensor_descriptor( + constexpr auto permuted_desc = transform_tensor_descriptor( TileDesc_K0_MN_K1{}, + make_tuple(make_xor_with_modulo_transform(make_tuple(Number{}, Number{})), + make_pass_through_transform(Number{})), + make_tuple(Sequence<1, 0>{}, Sequence<2>{}), + make_tuple(Sequence<1, 0>{}, Sequence<2>{})); + + return transform_tensor_descriptor( + permuted_desc, make_tuple(make_merge_transform_v3_division_mod(make_tuple(Number{}, Number{})), make_unmerge_transform(make_tuple(Number{}, Number{}, @@ -351,12 +362,29 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // not pad M or K const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor( a_grid_desc_mraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(AK0, AK1Value)), + make_tuple(make_unmerge_transform(make_tuple(K / KPerBlock, AK0Number, AK1Value)), make_pass_through_transform(M)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return a_grid_desc_ak0_m_ak1; + const auto a_grid_desc_permuted = transform_tensor_descriptor( + a_grid_desc_ak0_m_ak1, + make_tuple(make_pass_through_transform(K / 
KPerBlock), + make_xor_with_modulo_transform(make_tuple(M, AK0Number)), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto a_grid_desc = transform_tensor_descriptor( + a_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, AK0Number)), + make_pass_through_transform(M), + make_pass_through_transform(AK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return a_grid_desc; } } @@ -442,12 +470,30 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // not pad N or K const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor( b_grid_desc_nraw_kraw, - make_tuple(make_unmerge_transform(make_tuple(BK0, BK1Value)), - make_pass_through_transform(N)), + make_tuple( + make_unmerge_transform(make_tuple(K / KPerBlock, BK0Number, BK1Value)), + make_pass_through_transform(N)), make_tuple(Sequence<1>{}, Sequence<0>{}), - make_tuple(Sequence<0, 2>{}, Sequence<1>{})); + make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{})); - return b_grid_desc_bk0_n_bk1; + const auto b_grid_desc_permuted = transform_tensor_descriptor( + b_grid_desc_bk0_n_bk1, + make_tuple(make_pass_through_transform(K / KPerBlock), + make_xor_with_modulo_transform(make_tuple(N, BK0Number)), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<2, 1>{}, Sequence<3>{})); + + const auto b_grid_desc = transform_tensor_descriptor( + b_grid_desc_permuted, + make_tuple( + make_merge_transform_v3_division_mod(make_tuple(K / KPerBlock, BK0Number)), + make_pass_through_transform(N), + make_pass_through_transform(BK1Value)), + make_tuple(Sequence<0, 1>{}, Sequence<2>{}, Sequence<3>{}), + make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{})); + + return b_grid_desc; } else { @@ -648,10 +694,10 @@ 
struct GridwiseGemmMX_xdl_cshuffle_v3 bool is_reduce_ = false) : Problem{M_, N_, - K_/APackedSize, - StrideA_/APackedSize, + K_ / APackedSize, + StrideA_ / APackedSize, StrideScaleA_, - StrideB_/BPackedSize, + StrideB_ / BPackedSize, StrideScaleB_, StrideC_, k_batch_}, @@ -723,21 +769,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // Calculate A scale offset if constexpr(is_same_v) { - a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize); + a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / APackedSize); } else if constexpr(is_same_v) { - a_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/APackedSize) * karg.StrideScaleA; + a_scale_k_split_offset = + k_id * karg.KRead / (ScaleBlockSize / APackedSize) * karg.StrideScaleA; } // Calculate B scale offset if constexpr(is_same_v) { - b_scale_k_split_offset = k_id * (karg.KRead / (ScaleBlockSize/BPackedSize)) * karg.StrideScaleB; + b_scale_k_split_offset = + k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; } else if constexpr(is_same_v) { - b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize/BPackedSize); + b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); } if(k_id < (karg.KBatch - 1)) @@ -771,9 +819,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // A matrix in LDS memory, dst of blockwise copy if constexpr(ABlockLdsExtraM || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // contiguous in LDS return make_naive_tensor_descriptor( make_tuple(AK0Number, Number{}, AK1Number), - make_tuple(AK1Number, Number{}, I1)); + make_tuple(AK1Number, Number{}, I1)); } // xor tensor transformation request more unnecessary vgpr usage, would cause register spill // in some cases. 
@@ -888,9 +937,10 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // B matrix in LDS memory, dst of blockwise copy if constexpr(BBlockLdsExtraN || BlkGemmPipelineVer == BlockGemmPipelineVersion::v4) { + // contiguous in lds return make_naive_tensor_descriptor( make_tuple(BK0Number, Number{}, BK1Number), - make_tuple(BK1Number, Number{}, I1)); + make_tuple(BK1Number, Number{}, I1)); } else if constexpr(is_same::value) { @@ -1074,7 +1124,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 (NPerBlock % (NXdlPerWave * NPerXdl)) == 0, "Invalid tuning param!"); - static_assert(KPerBlock % (ScaleBlockSize/BPackedSize) == 0, + static_assert(KPerBlock % (ScaleBlockSize / BPackedSize) == 0, "KPerBlock should be multiple of ScaleBlockSize"); if constexpr(!(GemmSpec == tensor_operation::device::GemmSpecialization::MPadding || @@ -1381,67 +1431,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, 
a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -1449,12 +1474,11 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // Cast after lds auto a_block_buf = make_dynamic_buffer( - static_cast(p_shared), - a_block_desc_ak0_m_ak1.GetElementSpaceSize()); + static_cast(p_shared), a_block_desc_ak0_m_ak1.GetElementSpaceSize()); auto b_block_buf = make_dynamic_buffer( - reinterpret_cast(static_cast(p_shared) + a_block_space_size_aligned * - sizeof(ADataType)), + reinterpret_cast(static_cast(p_shared) + + a_block_space_size_aligned * sizeof(ADataType)), b_block_desc_bk0_n_bk1.GetElementSpaceSize()); constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock / AK1Number, 0, 
0); @@ -1556,7 +1580,6 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // shuffle C and write out { - // printf("c_thread_buf %f %f\n", c_thread_buf[I0], c_thread_buf[I1]); static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); @@ -1801,15 +1824,17 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // A/B shuffled scale for better 8-bit scale access pattern // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack - const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed(make_tuple( - problem.M / (MXdlPack * MPerXdl), - math::integer_divide_ceil(problem.K, (ScaleBlockSize/APackedSize)) / (KXdlPack * 64 / MPerXdl), - 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed(make_tuple( - problem.N / (NXdlPack * NPerXdl), - math::integer_divide_ceil(problem.K, (ScaleBlockSize/BPackedSize)) / (KXdlPack * 64 / NPerXdl), - 64 * KXdlPack * NXdlPack / scale_pack_size_b)); + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); Run( p_a_grid, a_grid_desc_ak0_m_ak1.GetElementSpaceSize()); const auto b_grid_buf = make_dynamic_buffer( @@ -1868,12 +1887,17 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 auto c_grid_buf = make_dynamic_buffer( p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize()); + // A Scale buffer + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, 
a_scale_grid_desc_am_ak.GetElementSpaceSize()); + // B Scale buffer const auto b_scale_grid_buf = make_dynamic_buffer( p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); - const AElementwiseOperation a_element_op{}; - const BElementwiseOperation b_element_op{}; + static_assert( + is_same_v && + is_same_v); const CElementwiseOperation c_element_op{}; // divide block work by [M, N] @@ -1909,67 +1933,42 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // B matrix in LDS memory, dst of blockwise copy constexpr auto b_block_desc_bk0_n_bk1 = GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1(); - // A matrix blockwise copy auto a_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ADataType, - ADataType, - decltype(a_grid_desc_ak0_m_ak1), - decltype(a_block_desc_ak0_m_ak1), - ABlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - ABlockTransferSrcVectorDim, - 2, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - 1, - 1, - AThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ADataType, + ADataType, + decltype(a_grid_desc_ak0_m_ak1), + decltype(a_block_desc_ak0_m_ak1), + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + 2, + ABlockTransferSrcScalarPerVector>( a_grid_desc_ak0_m_ak1, make_multi_index(0, m_block_data_idx_on_grid, 0), - a_element_op, a_block_desc_ak0_m_ak1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // B matrix blockwise copy auto b_blockwise_copy = - ThreadGroupTensorSliceTransfer_v4r1, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BDataType, - BDataType, - decltype(b_grid_desc_bk0_n_bk1), - decltype(b_block_desc_bk0_n_bk1), - 
BBlockTransferSrcAccessOrder, - Sequence<0, 1, 2>, - BBlockTransferSrcVectorDim, - 2, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - 1, - 1, - BThreadTransferSrcResetCoordinateAfterRun, - true, - BlockwiseGemmPipe::GlobalBufferNum>( + ThreadGroupTensorSliceTransfer_DirectLoad, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BDataType, + BDataType, + decltype(b_grid_desc_bk0_n_bk1), + decltype(b_block_desc_bk0_n_bk1), + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + 2, + BBlockTransferSrcScalarPerVector>( b_grid_desc_bk0_n_bk1, make_multi_index(0, n_block_data_idx_on_grid, 0), - b_element_op, b_block_desc_bk0_n_bk1, - make_multi_index(0, 0, 0), - ck::tensor_operation::element_wise::PassThrough{}); + make_multi_index(0, 0, 0)); // LDS allocation for A and B: be careful of alignment constexpr auto a_block_space_size_aligned = math::integer_least_multiple( @@ -2006,76 +2005,99 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 (a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2)) / KPerBlock); - // B scale - static constexpr auto mfma = - MfmaSelector{}; - static constexpr auto KPerXdlops = mfma.GetKPerXdlops(); - static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops(); - static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops; - static constexpr auto KPerThread = KPerBlock / K0PerXdlops; + // Initial thread mapping for: + // BlockSize = 256 + // MPerXdl=NPerXdl=32 and MPerBlock=NPerBlock=128 MRepeat=NRepeat=2 MWaves=NWaves=2 + // For each [m0, n0] tile, there are 4 waves: + // tId in [ 0, 63] m x n = [ 0, 31] x [ 0, 31] waveId = [0, 0] + // tId in [ 64, 127] m x n = [ 0, 31] x [32, 63] waveId = [0, 1] + // tId in [128, 191] m x n = [32, 63] x [ 0, 31] waveId = [1, 0] + // tId in [192, 255] m x n = [32, 63] x [32, 63] waveId = [1, 1] - const index_t ScaleSliceSizeN = NXdlPerWave; - static constexpr auto ScaleSliceSizeK = (KPerThread + 
(ScaleBlockSize/BPackedSize) - 1) / (ScaleBlockSize/BPackedSize); - static constexpr auto KBlockScaleSliceSizeK = - (KPerBlock + (ScaleBlockSize/BPackedSize) - 1) / (ScaleBlockSize/BPackedSize); + // BlockSize = 128 + // MPerXdl=NPerXdl=16 and MPerBlock=128 NPerBlock=16 MRepeat=4 NRepeat=1 MWaves=2 NWaves=1 + // For each [m0, n0] tile, there are 2 waves: + // tId in [ 0, 63] m x n = [ 0, 15] x [0, 15] waveId = [0, 0] + // tId in [ 64, 127] m x n = [16, 31] x [0, 15] waveId = [1, 0] - constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); + // TODO: Document initial thread mapping for more combinations of parameters - constexpr index_t NWaves = NPerBlock / (NXdlPerWave * NPerXdl); + const auto wave_idx = BlockwiseGemmPipe::GetWaveIdx(); + const auto waveId_m = wave_idx[I0]; + const auto waveId_n = wave_idx[I1]; - auto b_thread_offset_n = - get_thread_local_1d_id() % NPerXdl + - (get_thread_local_1d_id() / BlockwiseGemmPipe::WaveSize) % NWaves * NPerXdl; - auto b_thread_offset_k = - (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / NPerXdl * KPerThread; + // static constexpr auto mfma = BlockwiseGemmPipe::xdlops_gemm.mfma; - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0, 1>, - 1, - ScaleSliceSizeK, - 1, - false>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock + b_thread_offset_n, - b_thread_offset_k / (ScaleBlockSize/BPackedSize))); + // auto thread_offset_k = (get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize) / + // mfma.selected_mfma.num_threads_per_blk; - constexpr auto b_scale_thread_slice_copy_step = - make_tuple(make_multi_index(NWaves * NPerXdl, 0), - make_multi_index(-NPerBlock, 0), - make_multi_index(-NPerBlock, KBlockScaleSliceSizeK)); + // A wave access continuous memory + auto thread_offset_shuffled = + get_thread_local_1d_id() % BlockwiseGemmPipe::WaveSize * KXdlPack * MXdlPack; - blockwise_gemm_pipeline.template Run( - 
a_grid_desc_ak0_m_ak1, - a_block_desc_ak0_m_ak1, - a_blockwise_copy, - a_grid_buf, - a_block_bufs, - a_block_slice_copy_step, - b_grid_desc_bk0_n_bk1, - b_block_desc_bk0_n_bk1, - b_blockwise_copy, - b_grid_buf, - b_block_bufs, - b_block_slice_copy_step, - c_thread_buf, - b_scale_grid_desc_bn_ak, - b_scale_thread_desc, - b_scale_thread_copy, - b_scale_grid_buf, - b_scale_thread_slice_copy_step, - num_k_block_main_loop); + auto a_thread_offset_m = waveId_m; + + auto a_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + AScaleDataType, + AScaleDataType, + decltype(a_scale_grid_desc_am_ak), + decltype(BlockwiseGemmPipe::a_scale_thread_desc), + Sequence<1, 1, KXdlPack * MXdlPack / scale_pack_size_a>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * MXdlPack / scale_pack_size_a, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / MPerXdl / MXdlPack + a_thread_offset_m, + 0, + thread_offset_shuffled / scale_pack_size_a)); + + auto b_thread_offset_n = waveId_n; + + auto b_scale_thread_copy = ThreadwiseTensorSliceTransfer_v2< + BScaleDataType, + BScaleDataType, + decltype(b_scale_grid_desc_bn_ak), + decltype(BlockwiseGemmPipe::b_scale_thread_desc), + Sequence<1, 1, KXdlPack * NXdlPack / scale_pack_size_b>, // SliceLengths + Sequence<0, 1, 2>, // DimAccessOrder + 2, // SrcVectorDim + KXdlPack * NXdlPack / scale_pack_size_b, // SrcScalarPerVector + 1, // SrcScalarStrideInVector + true>(b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / NPerXdl / NXdlPack + b_thread_offset_n, + 0, + thread_offset_shuffled / scale_pack_size_b)); + + blockwise_gemm_pipeline.template Run(a_grid_desc_ak0_m_ak1, + a_block_desc_ak0_m_ak1, + a_blockwise_copy, + a_grid_buf, + a_block_bufs, + a_block_slice_copy_step, + b_grid_desc_bk0_n_bk1, + b_block_desc_bk0_n_bk1, + b_blockwise_copy, + b_grid_buf, + b_block_bufs, + b_block_slice_copy_step, + c_thread_buf, +
a_scale_grid_desc_am_ak, + a_scale_thread_copy, + a_scale_grid_buf, + b_scale_grid_desc_bn_ak, + b_scale_thread_copy, + b_scale_grid_buf, + num_k_block_main_loop); // shuffle C and write out { static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 && NXdlPerWave % CShuffleNXdlPerWavePerShuffle == 0, "wrong!"); + static_assert(CShuffleMXdlPerWavePerShuffle % MXdlPack == 0 && + CShuffleNXdlPerWavePerShuffle % NXdlPack == 0, + "wrong!"); constexpr index_t MWave = MPerBlock / (MXdlPerWave * MPerXdl); constexpr index_t NWave = NPerBlock / (NXdlPerWave * NPerXdl); @@ -2087,16 +2109,18 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // TODO: hacky, fix it! // c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp is only used to get lengths constexpr auto c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp = - blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(); + blockwise_gemm_pipeline.GetCBlockDescriptor_M0_N0_M1_N1_M2_N2_M3_M4_M5_N3(); constexpr auto M0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I0); constexpr auto N0 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I1); constexpr auto M1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I2); constexpr auto N1 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I3); constexpr auto M2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I4); - constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); - constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); - constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto N2 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I5); + constexpr auto M3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I6); + constexpr auto M4 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I7); + constexpr auto M5 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I8); + constexpr auto N3 = c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2_tmp.GetLength(I9); constexpr auto 
c_shuffle_block_desc_mblock_mperblock_nblock_nperblock = GetCShuffleBlockDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(); @@ -2110,19 +2134,25 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 make_tuple( make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // M0 (MXdlPerWave) per shuffle - M1, // M1 = MWave - M2, // M2 * M3 * M4 = MPerXdl - M3, - M4)), + Number{}, // M0 (MXdlPerWave) per + // shuffle + M1, // M1 = MWave + M2, // M2 = MXdlPack + M3, // M3 * M4 * M5 = MPerXdl + M4, + M5)), make_freeze_transform(I0), make_unmerge_transform(make_tuple( - Number{}, // N0 (NXdlPerWave) per shuffle - N1, // N1 = NWave - N2))), // N2 = NPerXdl + Number{}, // N0 (NXdlPerWave) per + // shuffle + N1, // N1 = NWave + N2, // N2 = NXdlPack + N3))), // N3 = NPerXdl make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}), - make_tuple( - Sequence<>{}, Sequence<0, 2, 4, 5, 6>{}, Sequence<>{}, Sequence<1, 3, 7>{})); + make_tuple(Sequence<>{}, + Sequence<0, 2, 4, 6, 7, 8>{}, + Sequence<>{}, + Sequence<1, 3, 5, 9>{})); // calculate origin of thread output tensor on global memory // blockwise GEMM c matrix starting index @@ -2134,8 +2164,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto m_thread_data_on_block_to_m0_m1_m2_m3_m4_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4))), - make_tuple(Sequence<0, 1, 2, 3, 4>{}), + make_tuple(make_merge_transform(make_tuple(M0, M1, M2, M3, M4, M5))), + make_tuple(Sequence<0, 1, 2, 3, 4, 5>{}), make_tuple(Sequence<0>{})); const auto m_thread_data_on_block_idx = @@ -2144,8 +2174,8 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 const auto n_thread_data_on_block_to_n0_n1_n2_adaptor = make_single_stage_tensor_adaptor( - make_tuple(make_merge_transform(make_tuple(N0, N1, N2))), - make_tuple(Sequence<0, 1, 2>{}), + make_tuple(make_merge_transform(make_tuple(N0, N1, N2, N3))), + make_tuple(Sequence<0, 1, 2, 3>{}), make_tuple(Sequence<0>{})); const auto 
n_thread_data_on_block_idx = @@ -2153,36 +2183,39 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 make_multi_index(n_thread_data_on_block)); // shuffle: threadwise copy C from VGPR to LDS - auto c_thread_copy_vgpr_to_lds = - ThreadwiseTensorSliceTransfer_v1r3, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - 7, - 1, - InMemoryDataOperationEnum::Set, - 1, - true>{ - c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, - make_multi_index(0, - 0, - m_thread_data_on_block_idx[I1], - n_thread_data_on_block_idx[I1], - m_thread_data_on_block_idx[I2], - m_thread_data_on_block_idx[I3], - m_thread_data_on_block_idx[I4], - n_thread_data_on_block_idx[I2]), - ck::tensor_operation::element_wise::PassThrough{}}; + auto c_thread_copy_vgpr_to_lds = ThreadwiseTensorSliceTransfer_v1r3< + AccDataType, + CShuffleDataType, + decltype(c_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2), + decltype(c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2), + ck::tensor_operation::element_wise::PassThrough, + Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + 9, + 1, + InMemoryDataOperationEnum::Set, + 1, + true>{c_block_desc_m0_n0_m1_n1_m2_m3_m4_n2, + make_multi_index(0, + 0, + m_thread_data_on_block_idx[I1], + n_thread_data_on_block_idx[I1], + m_thread_data_on_block_idx[I2], + n_thread_data_on_block_idx[I2], + m_thread_data_on_block_idx[I3], + m_thread_data_on_block_idx[I4], + m_thread_data_on_block_idx[I5], + n_thread_data_on_block_idx[I3]), + ck::tensor_operation::element_wise::PassThrough{}}; // shuffle: blockwise copy C from LDS to global auto c_shuffle_block_copy_lds_to_global = ThreadGroupTensorSliceTransfer_v6r1< @@ -2212,12 +2245,23 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 // space filling curve for threadwise C in VGPR constexpr auto sfc_c_vgpr = - SpaceFillingCurve, - Sequence<0, 1, 2, 3, 4, 5, 6, 7>, - Sequence, + Sequence<0, 1, 2, 3, 4, 5, 6, 7, 8, 9>, + Sequence __device__ static void Run_2Lds(const ADataType* p_a_grid, + const AScaleDataType* p_a_scale_grid, const BDataType* p_b_grid, const BScaleDataType* p_b_scale_grid, CDataType* 
p_c_grid, @@ -2286,22 +2331,33 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 problem.K, problem.KPadded, problem.N, problem.NPadded, problem.StrideB, problem.BK0); const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( c_grid_desc_m_n, problem.MBlock, problem.NBlock); - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(problem.N, math::integer_divide_ceil(problem.K, ScaleBlockSize/BPackedSize)), - make_tuple(problem.StrideScaleB, 1)); + // A/B shuffled scale for better 8-bit scale access pattern + // MNRepeat -> KRepeat -> KThreadPerXdl -> MNThreadPerXdl -> KXdlPack -> MNXdlPack + const auto a_scale_grid_desc_am_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.M / (MXdlPack * MPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / APackedSize)) / + (KXdlPack * 64 / MPerXdl), + 64 * KXdlPack * MXdlPack / scale_pack_size_a)); + + const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor_packed( + make_tuple(problem.N / (NXdlPack * NPerXdl), + math::integer_divide_ceil(problem.K, (ScaleBlockSize / BPackedSize)) / + (KXdlPack * 64 / NPerXdl), + 64 * KXdlPack * NXdlPack / scale_pack_size_b)); Run_2Lds(p_a_grid, + p_a_scale_grid, p_b_grid, p_b_scale_grid, p_c_grid, @@ -2309,6 +2365,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3 p_shared_1, problem, a_grid_desc_ak0_m_ak1, + a_scale_grid_desc_am_ak, b_grid_desc_bk0_n_bk1, b_scale_grid_desc_bn_ak, c_grid_desc_mblock_mperblock_nblock_nperblock); diff --git a/include/ck/utility/amd_buffer_addressing.hpp b/include/ck/utility/amd_buffer_addressing.hpp index 32a1329448..1b9865ad1e 100644 --- a/include/ck/utility/amd_buffer_addressing.hpp +++ b/include/ck/utility/amd_buffer_addressing.hpp @@ -1022,7 +1022,12 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, // Direct 
loads require that each thread reads and writes exactly a single DWORD. constexpr auto dword_bytes = 4; constexpr auto bytes_per_thread = sizeof(T) * NumElemsPerThread; +#if defined(__gfx950__) + static_assert(bytes_per_thread == dword_bytes || bytes_per_thread == dword_bytes * 3 || + bytes_per_thread == dword_bytes * 4); +#else static_assert(bytes_per_thread == dword_bytes); +#endif #ifndef CK_CODE_GEN_RTC const uint32_t* global_ptr = @@ -1059,7 +1064,7 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr, #endif llvm_amdgcn_raw_buffer_load_lds( - src_resource, lds_ptr, sizeof(uint32_t), global_offset_bytes, 0, 0, 0); + src_resource, lds_ptr, bytes_per_thread, global_offset_bytes, 0, 0, 0); #endif } #endif diff --git a/include/ck/utility/functional2.hpp b/include/ck/utility/functional2.hpp index a11963cb47..3d2c84727c 100644 --- a/include/ck/utility/functional2.hpp +++ b/include/ck/utility/functional2.hpp @@ -1,10 +1,11 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck/utility/functional.hpp" #include "ck/utility/sequence.hpp" +#include "ck/utility/tuple.hpp" namespace ck { @@ -70,4 +71,35 @@ struct static_for<0, N, 1> : detail::make_applier using detail::make_applier::operator(); }; +template +struct static_for_range +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + // tweak -fbracket-depth if compilation fails. Clang default limit is 256 + (f(Is{}), ...); + } +}; + +template +struct static_for_product; +template +struct static_for_product> : public static_for_range +{ +}; +template +struct static_for_product, Rest...> +{ + template + __host__ __device__ constexpr void operator()(F f) const + { + static_for_product>{}([&](auto i0) { // + static_for_product{}([&](auto... 
is) { // + f(i0, is...); + }); + }); + } +}; + } // namespace ck diff --git a/include/ck/utility/integral_constant.hpp b/include/ck/utility/integral_constant.hpp index 75f35d762c..a7fa64d710 100644 --- a/include/ck/utility/integral_constant.hpp +++ b/include/ck/utility/integral_constant.hpp @@ -5,14 +5,22 @@ namespace ck { +template +struct constant +{ + using value_type = decltype(v); + using type = constant; // using injected-class-name + static constexpr value_type value = v; + __host__ __device__ constexpr operator value_type() const noexcept { return value; } + __host__ __device__ constexpr value_type operator()() const noexcept { return value; } +}; + template -struct integral_constant +struct integral_constant : constant { static constexpr T value = v; typedef T value_type; typedef integral_constant type; - __host__ __device__ constexpr operator value_type() const noexcept { return value; } - __host__ __device__ constexpr value_type operator()() const noexcept { return value; } }; template diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp index a2732a20b4..48b38524ea 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn.hpp @@ -44,17 +44,18 @@ using device_gemm_mx_xdl_f4_f4_f16_mk_nk_mn_instances = std::tuple< //#############################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| 
SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| //#############################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| //#############################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 
0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 32, 256, 128, 16, 16, 16, 16, 2, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // 
DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 64, 256, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 128, 128, 16, 16, 16, 16, 6, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, + // DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 96, 256, 128, 16, 16, 16, 16, 6, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, 
GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, false, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 256, 128, 16, 16, 16, 16, 8, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 256, 128, 16, 16, 16, 16, 4, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 256, 128, 128, 16, 16, 16, 16, 8, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 256, 128, 128, 128, 16, 16, 16, 16, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 
1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3< Row, Col, Row, F4, E8M0PK, F4, E8M0PK, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, ScaleBlockSize, 64, 32, 32, 128, 16, 16, 16, 16, 2, 2, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 16, 1, 4>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, std::nullptr_t // clang-format on >; diff --git a/profiler/include/profiler/profile_gemm_mx_impl.hpp b/profiler/include/profiler/profile_gemm_mx_impl.hpp index 4b7a5e9834..9c980bbf51 100644 --- a/profiler/include/profiler/profile_gemm_mx_impl.hpp +++ b/profiler/include/profiler/profile_gemm_mx_impl.hpp @@ -213,8 +213,7 @@ bool profile_gemm_mx_impl(int do_verification, default: a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_m_k_scale.GenerateTensorValue( - GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); // R[2^-125, 1] + a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f});