diff --git a/example/67_gemm_microscaling/gemm_mx_bpreshuffle_common.hpp b/example/67_gemm_microscaling/gemm_mx_bpreshuffle_common.hpp deleted file mode 100644 index 9c373611bc..0000000000 --- a/example/67_gemm_microscaling/gemm_mx_bpreshuffle_common.hpp +++ /dev/null @@ -1,618 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include - -#include "ck/ck.hpp" -#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" -#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" -#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" -#include "ck/library/utility/host_tensor_generator.hpp" -#include "ck/utility/blkgemmpipe_scheduler.hpp" -#include "ck/utility/data_type.hpp" -#include "ck/utility/sequence.hpp" -#include "ck/library/reference_tensor_operation/cpu/reference_mx_gemm.hpp" -#include "ck/library/utility/check_err.hpp" -#include "ck/library/utility/device_memory.hpp" -#include "ck/library/utility/fill.hpp" -#include "ck/library/utility/host_tensor.hpp" - -template -using S = ck::Sequence; - -using Row = ck::tensor_layout::gemm::RowMajor; -using Col = ck::tensor_layout::gemm::ColumnMajor; - -using PassThrough = ck::tensor_operation::element_wise::PassThrough; - -using ck::type_convert; - -struct ExecutionConfig final -{ - int do_verification = 1; // (0=no, 1=CPU) - int init_method = 2; // (0=constant values, 1=integer values, 2=decimal values) - bool time_kernel = false; // (0=no, 1=yes) - int verbosity = 0; // (0=no info, 1=verbose info) -}; - -struct ProblemSizeSplitK final -{ - - ck::index_t M = 3840; - ck::index_t N = 4096; - ck::index_t K = 4096; - - ck::index_t StrideA = -1; - ck::index_t StrideB = -1; - ck::index_t StrideC = -1; - - ck::index_t KBatch = 1; -}; - -bool parse_cmd_args(int argc, - char* argv[], - ProblemSizeSplitK& problem_size, - ExecutionConfig& config) -{ - if(argc == 1) - { - // use default case - } - else if(argc == 5) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.verbosity = std::stoi(argv[4]); - } - else if(argc >= 11) - { - config.do_verification = std::stoi(argv[1]); - config.init_method = std::stoi(argv[2]); - config.time_kernel = std::stoi(argv[3]); - config.verbosity = std::stoi(argv[4]); - - problem_size.M = std::stoi(argv[5]); - problem_size.N = std::stoi(argv[6]); - problem_size.K = std::stoi(argv[7]); - - problem_size.StrideA = std::stoi(argv[8]); - problem_size.StrideB = std::stoi(argv[9]); - problem_size.StrideC = std::stoi(argv[10]); - - if(argc >= 12) - { - problem_size.KBatch = std::stoi(argv[11]); - } - } - else - { - std::cerr << "arg1: verification (0=no, 1=CPU)" << std::endl - << "arg2: initialization (0=constant values, 1=integer values, 2=decimal values)" - << std::endl - << "arg3: time kernel (0=no, 1=yes)" << std::endl - << "arg4: verbosity (0=no info, 1=verbose info)" << std::endl - << "arg5 to 10: M(128x), N(128x), K(256x), StrideA, StrideB, StrideC" << std::endl - << "arg11: KBatch" << std::endl; - return false; - } - - return true; -} - -#if 1 -template -void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, int K) -{ - int MNXdlPack = 2; - int KXdlPack = 2; - - int XdlMNThread = 16; - int XdlKThread = 64 / XdlMNThread; - - int K0 = K / KXdlPack / XdlKThread; // KRepeat - - // The 4 16x128 building blocks will be packed into 1 32x256 for F4 - // The 8 16x16x128 mfma will be packed into 1 32x32x256 for F4 - - // unfold the MN32xK(256/32) scale buffer - // 4 16 2 2 - // To XdlKThread-> XdlMNThread -> KXdlPack -> MNXdlPack - // Then, MNRepeat->KRepeat - - for(int n = 0; n < MN; ++n) - { - for(int k = 0; k < K; ++k) - { - int n0 = n / (XdlMNThread * MNXdlPack); // i MNRepeat - int tempn = n % (XdlMNThread * MNXdlPack); - int n1 = tempn % XdlMNThread; // i XdlMNThread - int n2 = tempn / XdlMNThread; // i MNXdlPack - - int k0 = k / (XdlKThread * KXdlPack); // i KRepeat - int tempk = k % (XdlKThread * KXdlPack); - int k1 = tempk % XdlKThread; // i XdlKThread - int k2 = tempk / XdlKThread; // i KXdlPack - - int outputIndex = n0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread * K0 + - k0 * MNXdlPack * KXdlPack * XdlMNThread * XdlKThread + - k1 * MNXdlPack * KXdlPack * XdlMNThread + n1 * MNXdlPack * KXdlPack + - k2 * MNXdlPack + n2; - // src[n * K + k] = ck::type_convert(static_cast(powf(2.0f, n2 + - // k2 * MNXdlPack))); - if constexpr(KLast) - dst[outputIndex] = src[n * K + k]; - else - dst[outputIndex] = src[k * MN + n]; - } - } -} - -void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl) -{ - int KPack = 16; - int NLane = NXdl; - int KLane = 64 / NLane; - int K_pk = K / 2; - int K0 = K_pk / (KLane * KPack); - // K -> K0 KLane KPack - // N -> N0 NLane - // N, K -> N0 K0 KLane NLane KPack - int tempk; - for(int n = 0; n < N; ++n) - { - for(int k = 0; k < K_pk; ++k) - { - int n0 = n / NLane; - int n1 = n % NLane; - - int k0 = k / (KLane * KPack); - tempk = k % (KLane * KPack); - int k1 = tempk / KPack; - int k2 = tempk % KPack; - - int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + - k1 * KPack * NLane + n1 * KPack + k2; - - dst[outputIndex] = src[n * K_pk + k]; - } - } -} -#endif - -template -bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) -{ - - auto M = problem_size.M; - auto N = problem_size.N; - auto K = problem_size.K; - auto StrideA = problem_size.StrideA; - auto StrideB = problem_size.StrideB; - auto StrideC = problem_size.StrideC; - auto KBatch = problem_size.KBatch; - - auto f_host_tensor_descriptor = - [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { - if constexpr(std::is_same_v) - { - return HostTensorDescriptor({row, col}, {stride, 1}); - } - else - { - return HostTensorDescriptor({row, col}, {1, stride}); - } - }; - - auto f_get_default_stride = - [](ck::index_t row, ck::index_t col, ck::index_t stride, auto layout) { - if(stride == -1) - { - // give a chance if stride is -1, return a default packed stride - if constexpr(std::is_same_v) - { - return static_cast(col); - } - else - { - return static_cast(row); - } - } - else - return static_cast(stride); - }; - - StrideA = f_get_default_stride(M, K, StrideA, ALayout{}); - StrideB = f_get_default_stride(K, N, StrideB, BLayout{}); - StrideC = f_get_default_stride(M, N, StrideC, CLayout{}); - - if(K % ScaleBlockSize != 0) - { - throw std::runtime_error("wrong! K must be multiple of ScaleBlockSize."); - }; - - // Hardcode scale layouts as per pipeline assumptions - // TODO: Allow user to specify scale layouts - using AScaleLayout = Row; - using BScaleLayout = Col; - - const auto APackedSize = []() { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - return 2; - else - return 1; - }(); - - const auto BPackedSize = []() { - if constexpr(ck::is_same_v, ck::pk_i4_t> || - ck::is_same_v, ck::f4x2_pk_t>) - return 2; - else - return 1; - }(); - - auto Scale_Stride_AM = f_get_default_stride(M, K / ScaleBlockSize, -1, AScaleLayout{}); - auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); - - Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); - Tensor b_preshuffled( - f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size - - Tensor a_m_k_scale(f_host_tensor_descriptor( - M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A - Tensor b_k_n_scale(f_host_tensor_descriptor( - K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B - - Tensor a_shuffled_scale(f_host_tensor_descriptor( - M, K / ScaleBlockSize, Scale_Stride_AM, AScaleLayout{})); // scales for A - Tensor b_shuffled_scale(f_host_tensor_descriptor( - K / ScaleBlockSize, N, Scale_Stride_BN, BScaleLayout{})); // scales for B - - Tensor c_m_n_host_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // host verification - Tensor c_m_n_device_result( - f_host_tensor_descriptor(M, N, StrideC, CLayout{})); // device result downloaded to host - - if(config.verbosity >= 0) - { - std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; - std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; - std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; - std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; - } - - auto a_data_element = [](float x) { - if constexpr(ck::is_same_v) - return ck::type_convert(ck::float2_t(x)); - else - return ck::type_convert(x); - }; - auto b_data_element = [](float x) { - if constexpr(ck::is_same_v) - return ck::type_convert(ck::float2_t(x)); - else - return ck::type_convert(x); - }; - - switch(config.init_method) - { - case 0: // Initializations for development and debugging - ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); - ck::utils::FillConstant{b_data_element(2.0f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); - if(config.verbosity > 0) - { - std::cout << "Init A = {1}" << std::endl; - std::cout << "Init A scale = {2.0}" << std::endl; - std::cout << "Init B = {0.5}" << std::endl; - std::cout << "Init B scale = {1.0}" << std::endl; - std::cout << "Expect C = {K}" << std::endl; - } - break; - - case 1: - ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); - ck::utils::FillConstant{b_data_element(1.0f)}(b_k_n); - // a_m_k_scale.GenerateTensorValue( - // GeneratorTensor_2{120, 129}); // scales: {0.25, 0.5, 1, 2} - // b_k_n_scale.GenerateTensorValue( - // GeneratorTensor_2{125, 129}); // scales: {0.25, 0.5, 1, 2} - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); - break; - case 2: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - ck::utils::FillConstant{b_data_element(1.0f)}(b_k_n); - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); - break; - case 3: - ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); - ck::utils::FillConstant{ck::type_convert(1.0f)}(b_k_n_scale); - break; - - case 4: - a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); - - b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); - b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); - break; - - default: - if(config.verbosity > 0) - { - std::cout << "NOTE: No input data initialization." << std::endl; - } - } - -#if 1 - preShuffleScaleBuffer>( - a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), M, K / ScaleBlockSize); - preShuffleScaleBuffer>( - b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); - - int NPerXdl = 16; // Fixed 16 - preShuffleBuffer(b_k_n.mData.data(), b_preshuffled.mData.data(), N, K, NPerXdl); -#endif - // printf("a:\n"); - // for(ck::index_t i = 0; i < M; i++) - // { - // for(ck::index_t j = 0; j < K; j += 2) - // { - // printf("%02x ", *reinterpret_cast(&a_m_k(i, j))); - // if(j % 32 == 31) - // { - // printf("\n"); - // } - // } - // printf("\n"); - // } - - // printf("b:\n"); - // for(ck::index_t i = 0; i < N; i++) - // { - // for(ck::index_t j = 0; j < K; j += 2) - // { - // printf("%02x ", *reinterpret_cast(&b_preshuffled(j, i))); - // if(j % 128 == 126) - // { - // printf("\n"); - // } - // } - // // printf("\n"); - // } - // printf("b_scale:\n"); - // for(ck::index_t i = 0; i < N; i++) - // { - // for(ck::index_t j = 0; j < K / ScaleBlockSize; j++) - // { - // // // b_k_n_scale(j, i) = - // // // ck::type_convert(static_cast(powf(2.0f, (j / 4) % 4))); - // // b_k_n_scale(j, i) =ck::type_convert(static_cast(1.0f)); - // // b_shuffled_scale(j, i) =ck::type_convert(static_cast(1.0f)); - // printf("%02x ", *reinterpret_cast(&b_k_n_scale(j, i))); - // } - // printf("\n"); - // } - - // printf("a_shuffled_scale:\n"); - // for(ck::index_t i = 0; i < M * K / ScaleBlockSize; i++) - // { - // printf("%02x ", *reinterpret_cast(&(a_shuffled_scale.mData.data()[i]))); - // if(i % 64 == 63) - // printf("\n"); - // } - // printf("b_shuffled_scale:\n"); - // for(ck::index_t i = 0; i < N * K / ScaleBlockSize; i++) - // { - // printf("%02x ", *reinterpret_cast(&(b_shuffled_scale.mData.data()[i]))); - // if(i % 64 == 63) - // printf("\n"); - // } - - if(config.verbosity > 0) - std::cout << "Device memory allocation..." << std::endl; - DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize()); - DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.GetElementSpaceSize()); - DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize()); - DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize()); - - if(config.verbosity > 0) - std::cout << "Upload data to device..." << std::endl; - a_device_buf.ToDevice(a_m_k.mData.data()); - a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data()); - b_device_buf.ToDevice(b_preshuffled.mData.data()); - b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data()); - - if(config.verbosity > 0) - std::cout << "Done." << std::endl; - - auto a_element_op = AElementOp{}; - auto b_element_op = BElementOp{}; - auto c_element_op = CElementOp{}; - - // run GEMM - auto device_op = DeviceOpInstance{}; - auto invoker = device_op.MakeInvoker(); - auto argument = - device_op.MakeArgument(static_cast(a_device_buf.GetDeviceBuffer()), - static_cast(a_scale_device_buf.GetDeviceBuffer()), - static_cast(b_device_buf.GetDeviceBuffer()), - static_cast(b_scale_device_buf.GetDeviceBuffer()), - static_cast(c_device_buf.GetDeviceBuffer()), - M, - N, - K, - StrideA, - Scale_Stride_AM, - StrideB, - Scale_Stride_BN, - StrideC, - KBatch, - a_element_op, - b_element_op, - c_element_op); - - if(!device_op.IsSupportedArgument(argument)) - { - throw std::runtime_error("wrong!\n" - "Provided combination of compilation and runtime parameters is " - "not consistent with the supported device_gemm arguments."); - } - - if(config.verbosity > 0) - { - std::cout << "Computing GEMM on device..." << std::endl << std::endl; - } - - float ave_time = - invoker.Run(argument, StreamConfig{nullptr, config.time_kernel, config.verbosity, 20, 50}); - - bool res_verified = true; - if(config.do_verification > 0) - { - c_device_buf.FromDevice(c_m_n_device_result.mData.data()); - if(config.verbosity > 0) - { - std::cout << "Done." << std::endl; - std::cout << "Computing GEMM on host..." << std::endl; - } - - using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceMXGemm; - auto ref_gemm = ReferenceGemmInstance{}; - auto ref_invoker = ref_gemm.MakeInvoker(); - - auto ref_argument = ref_gemm.MakeArgument(a_m_k, - a_m_k_scale, - b_k_n, - b_k_n_scale, - c_m_n_host_result, - PassThrough{}, - PassThrough{}, - PassThrough{}); - - ref_invoker.Run(ref_argument); - - if(config.verbosity > 0) - { - std::cout << "Done." << std::endl; - std::cout << "Comparing results..." << std::endl; - } - - // if(config.init_method == 0) - // { - // auto expected = static_cast(K); - // auto computed = type_convert(c_m_n_device_result(1, 12)); - - // res_verified = res_verified && std::abs(expected - computed) <= 0.0f; - // std::cout << "\nExpected vs Computed: " << expected << " vs " << computed - // << ((res_verified) ? " (PASSED!)" : " (FAILED!)") << std::endl - // << std::endl; - // } - - res_verified = res_verified && ck::utils::check_err(c_m_n_device_result, - c_m_n_host_result, - "Error: Incorrect results!"); - - if(config.verbosity > 0 && res_verified) - std::cout << "Verification Successful!" << std::endl; - } - else - { - if(config.verbosity > 0) - std::cout << "Done." << std::endl; - } - - if(config.time_kernel) - { - // Output size(M*N) * [dot product(2K) + product of scales(K/ScaleBlockSize) + scaling of - // partial sums(K/ScaleBlockSize)] - // FLOPS = 2 * M * N * K + 2 * M * N * K / ScaleBlockSize - std::size_t flop = std::size_t(2) * M * N * K + std::size_t(2) * M * N * K / ScaleBlockSize; - std::size_t num_btype = - sizeof(ADataType) * M * K / APackedSize + sizeof(BDataType) * K * N / BPackedSize + - sizeof(CDataType) * M * N + sizeof(XDataType) * M * K / ScaleBlockSize + - sizeof(XDataType) * N * K / ScaleBlockSize; - - float tflops = static_cast(flop) / 1.E9 / ave_time; - - float gb_per_sec = num_btype / 1.E6 / ave_time; - - std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec - << " GB/s, " << device_op.GetTypeString() << std::endl; - } - - return res_verified; -} - -template -bool run_mx_gemm_example(int argc, char* argv[]) -{ - ProblemSizeSplitK problem_size; - ExecutionConfig config; - - return parse_cmd_args(argc, argv, problem_size, config) && - run_mx_gemm(problem_size, config); -} diff --git a/example/67_gemm_microscaling/gemm_mx_common.hpp b/example/67_gemm_microscaling/gemm_mx_common.hpp index 8d52ecbe12..4c5fe4a12b 100644 --- a/example/67_gemm_microscaling/gemm_mx_common.hpp +++ b/example/67_gemm_microscaling/gemm_mx_common.hpp @@ -10,6 +10,7 @@ #include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" #include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp" #include "ck/library/utility/host_tensor_generator.hpp" #include "ck/utility/blkgemmpipe_scheduler.hpp" #include "ck/utility/data_type.hpp" @@ -154,6 +155,37 @@ void preShuffleScaleBuffer(ck::e8m0_bexp_t* src, ck::e8m0_bexp_t* dst, int MN, i } } } + +void preShuffleBuffer(const ck::f4x2_pk_t* src, ck::f4x2_pk_t* dst, int N, int K, int NXdl) +{ + int KPack = 16; + int NLane = NXdl; + int KLane = 64 / NLane; + int K_pk = K / 2; + int K0 = K_pk / (KLane * KPack); + // K -> K0 KLane KPack + // N -> N0 NLane + // N, K -> N0 K0 KLane NLane KPack + int tempk; + for(int n = 0; n < N; ++n) + { + for(int k = 0; k < K_pk; ++k) + { + int n0 = n / NLane; + int n1 = n % NLane; + + int k0 = k / (KLane * KPack); + tempk = k % (KLane * KPack); + int k1 = tempk / KPack; + int k2 = tempk % KPack; + + int outputIndex = n0 * KPack * NLane * KLane * K0 + k0 * KPack * NLane * KLane + + k1 * KPack * NLane + n1 * KPack + k2; + + dst[outputIndex] = src[n * K_pk + k]; + } + } +} #endif template + ck::index_t ScaleBlockSize, + bool BPreShuffle> bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& config) { @@ -221,7 +254,12 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto Scale_Stride_BN = f_get_default_stride(K / ScaleBlockSize, N, -1, BScaleLayout{}); Tensor a_m_k(f_host_tensor_descriptor(M, K, StrideA, ALayout{})); - Tensor b_k_n(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + auto b_k_n = + std::make_shared>(f_host_tensor_descriptor(K, N, StrideB, BLayout{})); + auto b_input = b_k_n; + if constexpr(BPreShuffle) + b_input = std::make_shared>( + f_host_tensor_descriptor(K, N, StrideB, BLayout{})); // use layout only for size // scales for A and B Tensor a_m_k_scale( @@ -244,7 +282,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c { std::cout << "a_m_k: " << a_m_k.mDesc << std::endl; std::cout << "a_m_k_scale: " << a_m_k_scale.mDesc << std::endl; - std::cout << "b_k_n: " << b_k_n.mDesc << std::endl; + std::cout << "b_k_n: " << b_k_n->mDesc << std::endl; std::cout << "b_k_n_scale: " << b_k_n_scale.mDesc << std::endl; std::cout << "c_m_n_device_result: " << c_m_n_device_result.mDesc << std::endl; } @@ -267,7 +305,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c case 0: // Initializations for development and debugging ck::utils::FillConstant{a_data_element(1.0f)}(a_m_k); ck::utils::FillConstant{ck::type_convert(1.0f)}(a_m_k_scale); - ck::utils::FillConstant{b_data_element(2.0f)}(b_k_n); + ck::utils::FillConstant{b_data_element(2.0f)}(*b_k_n); ck::utils::FillConstant{ck::type_convert(0.5f)}(b_k_n_scale); if(config.verbosity > 0) { @@ -281,8 +319,8 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c case 1: - a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] - b_k_n.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + a_m_k.GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] + b_k_n->GenerateTensorValue(GeneratorTensor_2{-5, 6}); // Z[-5,5] static_assert(ck::is_same_v); a_m_k_scale.GenerateTensorValue( GeneratorTensor_2{120, 129}); // scales: {0.25, 0.5, 1, 2} @@ -294,7 +332,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c a_m_k.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); a_m_k_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); - b_k_n.GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); + b_k_n->GenerateTensorValue(GeneratorTensor_3{-2.0, 2.0}); b_k_n_scale.GenerateTensorValue(GeneratorTensor_3{powf(2.0f, -125.0f), 1.0f}); break; @@ -310,6 +348,11 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c a_m_k_scale.mData.data(), a_shuffled_scale.mData.data(), M, K / ScaleBlockSize); preShuffleScaleBuffer>( b_k_n_scale.mData.data(), b_shuffled_scale.mData.data(), N, K / ScaleBlockSize); + if constexpr(BPreShuffle) + { + int NPerXdl = 16; // Fixed 16 + preShuffleBuffer(b_k_n->mData.data(), b_input->mData.data(), N, K, NPerXdl); + } #endif // printf("a_scale:\n"); // for(ck::index_t i = 0; i < M; i++) @@ -357,7 +400,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Device memory allocation..." << std::endl; DeviceMem a_device_buf(sizeof(ADataType) * a_m_k.GetElementSpaceSize()); DeviceMem a_scale_device_buf(sizeof(XDataType) * a_m_k_scale.GetElementSpaceSize()); - DeviceMem b_device_buf(sizeof(BDataType) * b_k_n.GetElementSpaceSize()); + DeviceMem b_device_buf(sizeof(BDataType) * b_k_n->GetElementSpaceSize()); DeviceMem b_scale_device_buf(sizeof(XDataType) * b_k_n_scale.GetElementSpaceSize()); DeviceMem c_device_buf(sizeof(CDataType) * c_m_n_device_result.GetElementSpaceSize()); @@ -365,7 +408,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c std::cout << "Upload data to device..." << std::endl; a_device_buf.ToDevice(a_m_k.mData.data()); a_scale_device_buf.ToDevice(a_shuffled_scale.mData.data()); - b_device_buf.ToDevice(b_k_n.mData.data()); + b_device_buf.ToDevice(b_input->mData.data()); b_scale_device_buf.ToDevice(b_shuffled_scale.mData.data()); if(config.verbosity > 0) @@ -405,7 +448,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c } std::size_t total_size = - a_m_k.GetElementSpaceSizeInBytes() + b_k_n.GetElementSpaceSizeInBytes() + + a_m_k.GetElementSpaceSizeInBytes() + b_k_n->GetElementSpaceSizeInBytes() + a_m_k_scale.GetElementSpaceSizeInBytes() + b_k_n_scale.GetElementSpaceSizeInBytes() + a_shuffled_scale.GetElementSpaceSizeInBytes() + b_shuffled_scale.GetElementSpaceSizeInBytes(); @@ -450,7 +493,7 @@ bool run_mx_gemm(const ProblemSizeSplitK& problem_size, const ExecutionConfig& c auto ref_argument = ref_gemm.MakeArgument(a_m_k, a_m_k_scale, - b_k_n, + *b_k_n, b_k_n_scale, c_m_n_host_result, PassThrough{}, @@ -525,7 +568,8 @@ template + ck::index_t MXVectorSize, + bool BPreShuffle = false> bool run_mx_gemm_example(int argc, char* argv[]) { ProblemSizeSplitK problem_size; @@ -546,5 +590,6 @@ bool run_mx_gemm_example(int argc, char* argv[]) CElementOp, AccDataType, CShuffleDataType, - MXVectorSize>(problem_size, config); + MXVectorSize, + BPreShuffle>(problem_size, config); } diff --git a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp index d3ff33d63c..d458f02e65 100644 --- a/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp +++ b/example/67_gemm_microscaling/gemm_mx_fp4_bpreshuffle.cpp @@ -1,7 +1,7 @@ // SPDX-License-Identifier: MIT // Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. -#include "gemm_mx_bpreshuffle_common.hpp" +#include "gemm_mx_common.hpp" using ADataType = ck::f4x2_pk_t; using BDataType = ck::f4x2_pk_t; @@ -99,7 +99,8 @@ int main(int argc, char* argv[]) CElementOp, AccDataType, CShuffleDataType, - ScaleBlockSize>(argc, argv) + ScaleBlockSize, + true>(argc, argv) ? 0 : -1; } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp index b8c0287783..7d21c44504 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_mx_bpreshuffle_selector.hpp @@ -3,41 +3,9 @@ #pragma once -// #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_mx.hpp" #include "ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_mx_bpreshuffle.hpp" namespace ck { - -/** - * @brief Define matrix data types that have hardware support for MX GEMMs - */ -template -static constexpr bool is_scale_mfma_data_type() -{ - using U = element_type_t; - return is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v; -} - -/** - * @brief Define scale data types that have hardware support for MX GEMMs - */ -template -static constexpr bool is_scale_mfma_scale_type() -{ - return is_same_v; -} - -/** - * @brief Combination of data types that have hardware support for MX GEMMs - */ -template -static constexpr bool scale_mfma_hw_support() -{ - return is_scale_mfma_data_type() && is_scale_mfma_data_type() && - is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); -} - template -static constexpr bool is_scale_mfma_data_type() -{ - using U = element_type_t; - return is_same_v || is_same_v || is_same_v || - is_same_v || is_same_v; -} - -/** - * @brief Define scale data types that have hardware support for MX GEMMs - */ -template -static constexpr bool is_scale_mfma_scale_type() -{ - return is_same_v; -} - -/** - * @brief Combination of data types that have hardware support for MX GEMMs - */ -template -static constexpr bool scale_mfma_hw_support() -{ - return is_scale_mfma_data_type() && is_scale_mfma_data_type() && - is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); -} - template 1) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; Run(kernel); } } @@ -337,7 +337,7 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmMX; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } else @@ -399,20 +399,20 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmMX 1) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; Run(kernel); } else { const auto kernel = - kernel_gemm_xdl_cshuffle_v3; + kernel_gemm_xdl_cshuffle_v3_b_preshuffle; Run(kernel); } } @@ -420,22 +420,22 @@ struct DeviceGemmMX_Xdl_CShuffleV3_BPreshuffle : public DeviceGemmMX; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; Run(kernel); } else { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3_2lds; + const auto kernel = kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; Run(kernel); } } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp index aab922abda..a603b1bd63 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_mx_bpreshuffle.hpp @@ -34,7 +34,7 @@ __global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_b_preshuffle(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()]; @@ -65,7 +65,7 @@ __global__ void __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif // __attribute__((amdgpu_waves_per_eu(1, 1))) - kernel_gemm_xdl_cshuffle_v3_2lds(typename GridwiseGemm::Argument karg) + kernel_gemm_xdl_cshuffle_v3_b_preshuffle_2lds(typename GridwiseGemm::Argument karg) { #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx9__)) // Pass two lds pointer is the key to tell compiler that ds_read/write diff --git a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp index 549d69257b..9248af0a4b 100644 --- a/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp +++ b/include/ck/tensor_operation/gpu/warp/xdlops_gemm.hpp @@ -8,6 +8,35 @@ #include "ck/utility/amd_xdlops.hpp" namespace ck { +/** + * @brief Define matrix data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_data_type() +{ + using U = element_type_t; + return is_same_v || is_same_v || is_same_v || + is_same_v || is_same_v; +} + +/** + * @brief Define scale data types that have hardware support for MX GEMMs + */ +template +static constexpr bool is_scale_mfma_scale_type() +{ + return is_same_v; +} + +/** + * @brief Combination of data types that have hardware support for MX GEMMs + */ +template +static constexpr bool scale_mfma_hw_support() +{ + return is_scale_mfma_data_type() && is_scale_mfma_data_type() && + is_scale_mfma_scale_type() && is_scale_mfma_scale_type(); +} enum struct MfmaInstr {