diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 29d1cafc6a..0205bf2668 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -10,7 +10,6 @@ #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" @@ -24,8 +23,9 @@ template using S = ck::Sequence; -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; +using MFMA = ck::tensor_layout::gemm::MFMA; using PassThrough = ck::tensor_operation::element_wise::PassThrough; @@ -202,10 +202,11 @@ template + ck::index_t ScaleBlockSize> bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { + constexpr bool BPreShuffle = ck::is_same_v; + using BRefLayout = ck::conditional_t; auto M = problem_size.M; auto N = problem_size.N; @@ -257,11 +258,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); auto b_k_n = - std::make_shared>(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + std::make_shared>(f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); auto b_input = b_k_n; if constexpr(BPreShuffle) b_input = std::make_shared>( - f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size + f_host_tensor_descriptor(K, N, StrideB, BRefLayout{})); // use layout only for size // scales for A and B Tensor a_m_k_scale(f_host_tensor_descriptor( @@ -350,7 +351,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c a_shuffled_scale.mData.data(), Scale_Padded_M, K / ScaleBlockSize); - preShuffleScaleBuffer>( + preShuffleScaleBuffer>( b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); if constexpr(BPreShuffle) { @@ -572,8 +573,7 @@ template + ck::index_t MXVectorSize> bool run_mx_gemm_example(int argc, char* argv[]) { ProblemSizeSplitK problem_size; @@ -594,6 +594,5 @@ bool run_mx_gemm_example(int argc, char* argv[]) CElementOp, AccDataType, CShuffleDataType, - MXVectorSize, - BPreShuffle>(problem_size, config); + MXVectorSize>(problem_size, config); } diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp index d458f02e65..562b2fdb17 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp @@ -16,7 +16,7 @@ using AccDataType = float; using CShuffleDataType = CDataType; using ALayout = Row; -using BLayout = Col; +using BLayout = MFMA; using CLayout = Row; using AElementOp = PassThrough; // elementwise transformation for A matrix @@ -33,7 +33,7 @@ constexpr auto BlkGemmPVer = ck::BlockGemmPipelineVersion::v3; // AB DataType: f4x2_pk_t // Mathmatically, all numbers are represented as f4x2. -using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle< +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMX_Xdl_CShuffleV3< ALayout, // ALayout BLayout, // BLayout CLayout, // CLayout @@ -99,8 +99,7 @@ int main(int argc, char* argv[]) CElementOp, AccDataType, CShuffleDataType, - ScaleBlockSize, - true>(argc, argv) + ScaleBlockSize>(argc, argv) ? 0 : -1; } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp index d088d3775d..ed168195ec 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp @@ -15,6 +15,7 @@ #include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" @@ -162,56 +163,108 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX { // GridwiseGemm - using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>; + using GridwiseGemm = conditional_t< // + !is_same_v, + GridwiseGemmMX_xdl_cshuffle_v3< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>, + GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< + ALayout, + BLayout, + CLayout, + ADataType, + AScaleDataType, + BDataType, + BScaleDataType, + GemmAccDataType, + CShuffleDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + ScaleBlockSize, + BlockSize, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerXDL, + NPerXDL, + MXdlPerWave, + NXdlPerWave, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMXdlPerWavePerShuffle, + CShuffleNXdlPerWavePerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVector_NPerBlock, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>>; using Argument = typename GridwiseGemm::Argument; @@ -310,9 +363,15 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX, constant>{}; else - { static_assert(false, "Unexpected BlkGemmPipelineVer!"); - } + }(); + constexpr bool Use2LDS = []() { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + return false; + else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + return true; + else + static_assert(false, "Unexpected BlkGemmPipelineVer!"); }(); const TailNumber tail_num = GridwiseGemm::CalculateKBlockLoopTailNum(K_split); using BoolChoices = Tuple; @@ -327,31 +386,14 @@ struct DeviceGemmMX_Xdl_CShuffleV3 : public DeviceGemmMX 1) && tail_num_choice.value == tail_num) { - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3< // - GridwiseGemm, - mainloop_choice.value, - CGlobalMemoryDataOperation, - minimum_occupancy, - tail_num_choice.value>; - Run(kernel); - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - - const auto kernel = kernel_gemm_xdl_cshuffle_v3_2lds< // - GridwiseGemm, - mainloop_choice.value, - CGlobalMemoryDataOperation, - minimum_occupancy, - tail_num_choice.value>; - Run(kernel); - } - else - { - static_assert(false, "Unexpected BlkGemmPipelineVer!"); - } + const auto kernel = kernel_gemm_xdl_cshuffle_v3_mx< // + Use2LDS, + GridwiseGemm, + mainloop_choice.value, + CGlobalMemoryDataOperation, + minimum_occupancy, + tail_num_choice.value>; + Run(kernel); } }); return ave_time; diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp deleted file mode 100644 index 8e2aef991d..0000000000 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ /dev/null @@ -1,638 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include - -#include "ck/utility/common_header.hpp" - -#include "ck/host_utility/flush_cache.hpp" -#include "ck/tensor_description/tensor_descriptor.hpp" -#include "ck/tensor_description/tensor_descriptor_helper.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/device_gemm_mx.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" -#include "ck/host_utility/device_prop.hpp" -#include "ck/host_utility/kernel_launch.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { - -// clang-format off -/** - * \brief WIP: Implements XDL CShuffle V3 GEMM for microscale-compliant data types - * - * This class is a work-in-progress implementation of the XDL CShuffle V3 GEMM for - * microscale-compliant data types. - * - * Assumptions: - * - A and B data types are compliant with the OCP Microscaling Formats (MX) Specification - * - Each scale applies to ScaleBlockSize elements in K direction - * - A scale matrix is a row-major - * - B scale matrix is a column-major - * - Scale data types must have get_exponent_value() specialization, whereas lowest 8 bits of the - * exponent will be interpreted as conventional biased Float32 exponent (E8M0) - * - * Tunable parameters. - * The CK instance includes a series of tunable template parameters to control the parallel - * granularity of the workload to achieve load balancing on different hardware platforms. These - * parameters include Block Size, M/N/K Per Block, M/N per XDL, AK1, BK1, etc. - * - Block Size determines the number of threads in the thread block. - * - M/N/K Per Block determines the size of tile that each thread block is responsible for - * calculating. - * - M/N Per XDL refers to M/N size for Instinct accelerator Matrix Fused Multiply Add (MFMA) - * instructions operating on a per-wavefront basis. - * - A/B K1 is related to the data type. It can be any value ranging from 1 to K Per Block. To - * achieve the optimal load/store performance, 128bit per load is suggested. In addition, the A/B - * loading parameters must be changed accordingly to match the A/B K1 value; otherwise, it will - * result in compilation errors. - * - * Conditions for achieving computational load balancing on different hardware platforms can vary. - * - * Serialized version of the algorithm: - * \code - * // E = A * B + C - * // Loop over E[MPerBlock,NPerBlock] tiles - * for(int mb = 0; mb < M; mb += MPerBlock){ - * for(int nb = 0; nb < N; nb += NPerBlock){ - * // initialize E[MPerBlock,NPerBlock] tile - * for(int mt = mb; mt < mb + MPerBlock; mt++){ - * for(int nt = nb; nt < nb + NPerBlock; nt++){ - * E[mt,nt] = C[mt,nt]; - * } - * } - * - * // multiply-accumulate per tile - * for(int kb = 0; kb < K; kb += KPerBlock){ - * for(int m0 = mb; m0 < mb + MPerBlock; m0 += MWaves * MPerXDL){ - * for(int n0 = nb; n0 < nb + NPerBlock; n0 += NWaves * NPerXDL){ - * for(int mw = m0; mw < m0 + MWaves * MPerXDL; mw += MPerXDL){ - * for(int nw = n0; nw < n0 + NWaves * NPerXDL; nw += NPerXDL){ - * for(int k0 = kb; k0 < kb + KPerBlock; k0 += mfma.num_input_blks*KPack){ - * // MFMA accumulation - * for(int k_pack = k0; k_pack < k0 + mfma.num_input_blks*KPack; k_pack += KPerXdlops){ - * // MFMA instruction - * for(int k_mfma = k_pack; k_mfma < k_pack + KPerXdlops; k_mfma += mfma.k_per_blk){ - * for(int m = mw; m < mw + MPerXDL; m++){ - * for(int n = nw; n < nw + NPerXDL; n++){ - * for(int k = k_mfma; k < k_mfma + mfma.k_per_blk; k++){ - * E[m,n] += A[m,k] * B[k,n]; - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * } - * \endcode - * - */ -// clang-format on -template -struct DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmMX -{ - // GridwiseGemm - using GridwiseGemm = GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle< - ALayout, - BLayout, - CLayout, - ADataType, - AScaleDataType, - BDataType, - BScaleDataType, - GemmAccDataType, - CShuffleDataType, - CDataType, - AElementwiseOperation, - BElementwiseOperation, - CElementwiseOperation, - GemmSpec, - ScaleBlockSize, - BlockSize, - MPerBlock, - NPerBlock, - KPerBlock, - AK1, - BK1, - MPerXDL, - NPerXDL, - MXdlPerWave, - NXdlPerWave, - ABlockTransferThreadClusterLengths_AK0_M_AK1, - ABlockTransferThreadClusterArrangeOrder, - ABlockTransferSrcAccessOrder, - ABlockTransferSrcVectorDim, - ABlockTransferSrcScalarPerVector, - ABlockTransferDstScalarPerVector_AK1, - false, - ABlockLdsExtraM, - BBlockTransferThreadClusterLengths_BK0_N_BK1, - BBlockTransferThreadClusterArrangeOrder, - BBlockTransferSrcAccessOrder, - BBlockTransferSrcVectorDim, - BBlockTransferSrcScalarPerVector, - BBlockTransferDstScalarPerVector_BK1, - false, - BBlockLdsExtraN, - CShuffleMXdlPerWavePerShuffle, - CShuffleNXdlPerWavePerShuffle, - CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, - CShuffleBlockTransferScalarPerVector_NPerBlock, - BlkGemmPipeSched, - BlkGemmPipelineVer, - ComputeTypeA, - ComputeTypeB>; - - using Argument = typename GridwiseGemm::Argument; - - // Invoker - struct Invoker : public BaseInvoker - { - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideA, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideB, arg_.BK0); - - auto size_a_buffer = - a_grid_desc_ak0_m_ak1.GetElementSpaceSize() * sizeof(ADataType); - auto size_b_buffer = - b_grid_desc_bk0_n_bk1.GetElementSpaceSize() * sizeof(BDataType); - - ck::utility::RotatingMemWrapper rotating_mem( - arg_, stream_config.rotating_count, size_a_buffer, size_b_buffer); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg_.p_c_grid, - 0, - arg_.M * arg_.N * sizeof(CDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_); - } - else - { - if(arg.KBatch > 1) - hipGetErrorString(hipMemsetAsync(arg.p_c_grid, - 0, - arg.M * arg.N * sizeof(CDataType), - stream_config.stream_id_)); - - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); - } - }; - - // TODO: Check if this is the right algorithm for minimum_occupancy - constexpr index_t minimum_occupancy = - BlkGemmPipeSched == BlockGemmPipelineScheduler::Intrawave - ? (BlkGemmPipelineVer == BlockGemmPipelineVersion::v3 && - MPerBlock * NPerBlock * KPerBlock * sizeof(ADataType) <= 128 * 128 * 64 * 2) - ? 2 - : 1 - : 2; - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - } - // Tail number could be Odd or Even - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { -#if 0 - if(arg.KBatch > 1) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - } -#endif - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - else - { - throw std::runtime_error("wrong! BlkGemmPipelineVer"); - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy>; - Run(kernel); - } - else - { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_b_preshuffle; - Run(kernel); - } - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; - - static constexpr bool IsValidCompilationParameter() - { - static_assert(is_scale_mfma_data_type() && is_scale_mfma_data_type(), - "Only microscaling formats are supported for ADataType and BDataType"); - - static_assert(ScaleBlockSize == 32, "Only ScaleBlockSize 32 is supported"); - - static_assert(is_same_v && is_same_v, - "ComputeTypeA and ComputeTypeB must be the same as ADataType and BDataType"); - - return true; - } - - static bool IsSupportedArgument(const Argument& arg) - { - if constexpr(!IsValidCompilationParameter()) - { - return false; - } - - if(ck::get_device_name() != "gfx950") - { - return false; - } - - if(!is_bf16_atomic_supported() && std::is_same_v && arg.KBatch > 1) - { - return false; - } - - if((arg.K % AK1 != 0 || arg.K % BK1 != 0) && !(GemmSpec == GemmSpecialization::MKPadding || - GemmSpec == GemmSpecialization::NKPadding || - GemmSpec == GemmSpecialization::MNKPadding || - GemmSpec == GemmSpecialization::KPadding)) - { - return false; - } - - return GridwiseGemm::CheckValidity(arg); - } - - // polymorphic - bool IsSupportedArgument(const BaseArgument* p_arg) override - { - return IsSupportedArgument(*dynamic_cast(p_arg)); - } - - static auto MakeArgument(const ADataType* p_a, - const AScaleDataType* p_a_scale, - const BDataType* p_b, - const BScaleDataType* p_b_scale, - CDataType* p_c, - index_t M, - index_t N, - index_t K, - index_t StrideA, - index_t StrideScaleA, - index_t StrideB, - index_t StrideScaleB, - index_t StrideC, - index_t KBatch, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) - { - return Argument{p_a, - p_a_scale, - p_b, - p_b_scale, - p_c, - M, - N, - K, - StrideA, - StrideScaleA, - StrideB, - StrideScaleB, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op}; - } - - static auto MakeInvoker() { return Invoker{}; } - - // polymorphic - std::unique_ptr MakeArgumentPointer(const void* p_a, - const void* p_a_scale, - const void* p_b, - const void* p_b_scale, - void* p_c, - ck::index_t M, - ck::index_t N, - ck::index_t K, - ck::index_t StrideA, - ck::index_t StrideScaleA, - ck::index_t StrideB, - ck::index_t StrideScaleB, - ck::index_t StrideC, - ck::index_t KBatch, - AElementwiseOperation a_element_op, - BElementwiseOperation b_element_op, - CElementwiseOperation c_element_op) override - { - return std::make_unique(static_cast(p_a), - static_cast(p_a_scale), - static_cast(p_b), - static_cast(p_b_scale), - static_cast(p_c), - M, - N, - K, - StrideA, - StrideScaleA, - StrideB, - StrideScaleB, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op); - } - - // polymorphic - std::unique_ptr MakeInvokerPointer() override - { - return std::make_unique(Invoker{}); - } - - // polymorphic - std::string GetTypeString() const override - { - auto str = std::stringstream(); - - std::map BlkGemmPipelineSchedulerToString{ - {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, - {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; - - std::map BlkGemmPipelineVersionToString{ - {BlockGemmPipelineVersion::v1, "v1"}, - {BlockGemmPipelineVersion::v2, "v2"}, - {BlockGemmPipelineVersion::v3, "v3"}, - {BlockGemmPipelineVersion::v4, "v4"}, - {BlockGemmPipelineVersion::v5, "v5"}}; - - // clang-format off - str << "DeviceGemmMX_Xdl_CShuffleV3" - << "<" - << getGemmSpecializationString(GemmSpec) << ", " - << std::string(ALayout::name)[0] - << std::string(BLayout::name)[0] - << std::string(CLayout::name)[0] - << ">" - << " BlkSize: " - << BlockSize << ", " - << "BlkTile: " - << MPerBlock<<"x"< -__global__ void +__global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -55,17 +58,18 @@ __global__ void #endif // end of if (defined(__gfx9__)) } -template -__global__ void +__global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { #if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ // Pass two lds pointer is the key to tell compiler that ds_read/write @@ -89,6 +93,7 @@ __global__ void ignore = karg; #endif // end of if (defined(__gfx9__)) } +#endif template -__global__ void +__global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); @@ -55,23 +58,25 @@ __global__ void #endif // end of if (defined(__gfx9__)) } -template -__global__ void +__global__ enable_if_t #if CK_USE_LAUNCH_BOUNDS __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_mx(typename GridwiseGemm::Argument karg) { -#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) +#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__ // Pass two lds pointer is the key to tell compiler that ds_read/write // operate on different lds chunk at same time without order dependecy __shared__ char p_shared_0[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_1[GridwiseGemm::GetSharedMemoryNumberOfByte()]; + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); GridwiseGemm::template Run_2Lds( @@ -88,6 +93,7 @@ __global__ void ignore = karg; #endif // end of if (defined(__gfx9__)) } +#endif template ::value) + else if constexpr(is_same::value) { return make_naive_tensor_descriptor(make_tuple(N, K), make_tuple(StrideB, I1)); } @@ -796,7 +802,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle { b_k_split_offset = k_id * karg.KRead * karg.StrideB; } - else if constexpr(is_same_v) + else if constexpr(is_same_v) { if constexpr(!PermuteB) { @@ -826,7 +832,7 @@ struct GridwiseGemmMX_xdl_cshuffle_v3_bpreshuffle b_scale_k_split_offset = k_id * (karg.KRead / (ScaleBlockSize / BPackedSize)) * karg.StrideScaleB; } - else if constexpr(is_same_v) + else if constexpr(is_same_v) { b_scale_k_split_offset = k_id * karg.KRead / (ScaleBlockSize / BPackedSize); } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp index cd0b29ba1b..88f6238d8f 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx.hpp @@ -173,6 +173,73 @@ struct DeviceOperationInstanceFactory< } }; +void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances( + std::vector>>& instances); + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMX, + enable_if_t>> +{ + using DeviceOp = DeviceGemmMX; + + static auto GetInstances() + { + std::vector> op_ptrs; + + if constexpr(is_same_v && is_same_v && is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(op_ptrs); + } + } + + return op_ptrs; + } +}; } // namespace instance } // namespace device } // namespace tensor_operation diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx_wp.hpp deleted file mode 100644 index 7e149d26ca..0000000000 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_mx_wp.hpp +++ /dev/null @@ -1,91 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include -#include -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" -#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" - -#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp" - -namespace ck { -namespace tensor_operation { -namespace device { -namespace instance { - -void add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances( - std::vector>>& instances); - -template -struct DeviceOperationInstanceFactory< - ck::tensor_operation::device::DeviceGemmMX, - enable_if_t>> -{ - using DeviceOp = DeviceGemmMX; - - static auto GetInstances() - { - std::vector> op_ptrs; - - if constexpr(is_same_v && is_same_v && is_same_v) - { - if constexpr(is_same_v && is_same_v && - is_same_v) - { - add_device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_default_instances(op_ptrs); - } - } - - return op_ptrs; - } -}; - -} // namespace instance -} // namespace device -} // namespace tensor_operation -} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp index b1bb922cd3..d0991a82e7 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp +++ b/library/src/tensor_operation_instance/gpu/gemm_mx/device_gemm_mx_xdl_f4_f4_f16/device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn.hpp @@ -4,7 +4,7 @@ #include "ck/ck.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" #include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" @@ -42,29 +42,29 @@ static constexpr auto ScaleBlockSize = 32; template using device_gemm_mx_xdl_f4_f4_f16_mk_mfma_mn_instances = std::tuple< // clang-format off - //#################################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| - //#################################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| - //#################################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| - //#################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + //#####################| ALayout| BLayout| CLayout|AData| AScale|BData| BScale| CData| AccData| Cshuffle| A| B| C| GEMM| Scale Block| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#####################| | | | Type| Data| Type| Data| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline| + //#####################| | | | | Type| | Type| | | | Operation| Operation| Operation| | | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision| + //#####################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, - // DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, + // DeviceGemmMX_Xdl_CShuffleV3, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 2, 2, S<1, 32, 1, 8>, 8, BlkGemmPipeSched, BlockGemmPipelineVersion::v3>, std::nullptr_t // clang-format on >; diff --git a/profiler/include/profiler/profile_gemm_mx_impl.hpp b/profiler/include/profiler/profile_gemm_mx_impl.hpp index feace09f2c..46650b82fb 100644 --- a/profiler/include/profiler/profile_gemm_mx_impl.hpp +++ b/profiler/include/profiler/profile_gemm_mx_impl.hpp @@ -10,7 +10,6 @@ #include "ck/ck.hpp" #include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" #include "ck/library/tensor_operation_instance/gpu/gemm_mx.hpp" -#include "ck/library/tensor_operation_instance/gpu/gemm_mx_wp.hpp" #include "ck/library/utility/check_err.hpp" #include "ck/library/utility/device_memory.hpp" #include "ck/library/utility/fill.hpp" @@ -18,7 +17,6 @@ #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/library/utility/literals.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" #include "ck/utility/data_type.hpp"