From 87cf3a4fe274e3c778ffa39e3c02e6495da7392d Mon Sep 17 00:00:00 2001 From: Enrico Degregori <73224202+EnricoDeg@users.noreply.github.com> Date: Thu, 11 Dec 2025 09:06:20 +0100 Subject: [PATCH] Wmma support for gemm_ab_scale (#3314) * Support gemm_ab_scale: - Add tests - Integrate scaling implementation in multiple D - Generalize existing b_scale for ab_scale - Add instances - Generalize implementation for ScaleBlockM, ScaleBlockN, ScaleBlockK - Add support for all layouts supported by xdl - Fix splitk xdl * Fix copyright * Wmma support for gemm_blockscale_wp (#3315) * Support for preshuffle with ab scale - add support for b preshuffle in GridwiseGemm_wmma_cshuffle_v3_ab_scale - add support for AScaleLayout and BScaleLayout (can be different from ALayout and BLayout, respectively) - add Run method in v1 pipeline to support preshuffle + scaling - add support for preshuffle gemms in common invoker - Add splitk support * Fix copyright header [ROCm/composable_kernel commit: ce99cab6056d1ffef5acb6f4ad7ede87a46a3cfc] --- .../65_gemm_multiply_multiply/CMakeLists.txt | 2 + ...mm_multiply_multiply_wmma_fp8_ab_scale.cpp | 345 +++++++++++++ ...ltiply_wmma_fp8_blockscale_bpreshuffle.cpp | 357 +++++++++++++ .../blockwise_gemm_pipeline_wmmaops_base.hpp | 146 ++++-- .../blockwise_gemm_pipeline_wmmaops_v1.hpp | 468 +++++++++++++++++- .../blockwise_gemm_pipeline_wmmaops_v3.hpp | 345 ++++++++++++- .../device_gemm_multiple_d_ab_scale.hpp | 347 +++++++++++++ ..._batched_gemm_wmma_cshuffle_v3_b_scale.hpp | 11 +- ...m_multiple_d_wmma_cshuffle_v3_ab_scale.hpp | 362 ++++++++++++++ ...ltiple_d_wmma_cshuffle_v3_b_preshuffle.hpp | 308 +----------- ...mma_cshuffle_v3_blockscale_bpreshuffle.hpp | 360 ++++++++++++++ ...mm_multiple_d_xdl_cshuffle_v3_ab_scale.hpp | 102 +++- .../device_gemm_wmma_cshuffle_v3_b_scale.hpp | 10 +- .../device_gemm_wmma_cshuffle_v3_common.hpp | 200 ++++++-- .../gridwise_ab_transfer_thread_tiles.hpp | 10 +- ...ise_batched_gemm_gemm_wmma_cshuffle_v3.hpp | 6 +- .../grid/gridwise_gemm_wmma_cshuffle_v3.hpp | 7 +- ...idwise_gemm_wmma_cshuffle_v3_ab_scale.hpp} | 393 +++++++++++---- .../gridwise_gemm_wmma_cshuffle_v3_common.hpp | 47 +- ..._gemm_xdl_cshuffle_v3_multi_d_ab_scale.hpp | 74 ++- .../gpu/gemm_ab_scale.hpp | 394 ++++++++++++++- .../gpu/gemm_blockscale_wp.hpp | 147 ++++++ .../gpu/CMakeLists.txt | 12 +- .../gpu/gemm_ab_scale/CMakeLists.txt | 21 +- ...e_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp | 79 +++ ...n_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ ...e_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp | 80 +++ ...n_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ ...e_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 95 ++++ ...k_mn_128_128_128_comp_default_instance.cpp | 37 ++ ..._mn_128_128_128_comp_kpadding_instance.cpp | 37 ++ ...mn_128_128_128_mem_v1_default_instance.cpp | 38 ++ ...n_128_128_128_mem_v1_kpadding_instance.cpp | 38 ++ .../gpu/gemm_blockscale_wp/CMakeLists.txt | 5 +- ...p_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp | 77 +++ ...k_mn_128_128_128_comp_default_instance.cpp | 38 ++ ...nk_mn_128_128_128_mem_default_instance.cpp | 38 ++ .../profiler/profile_gemm_ab_scale_impl.hpp | 6 +- .../profile_gemm_blockscale_wp_impl.hpp | 2 +- test/CMakeLists.txt | 1 +
test/gemm_ab_scale/CMakeLists.txt | 9 + test/gemm_ab_scale/test_gemm_ab_scale.cpp | 236 +++++++++ .../gemm_ab_scale/test_gemm_ab_scale_util.hpp | 102 ++++ test/gemm_blockscale_wp/CMakeLists.txt | 4 +- ...p8.cpp => test_gemm_blockscale_wp_fp8.cpp} | 0 51 files changed, 5144 insertions(+), 552 deletions(-) create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp create mode 100644 example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp create mode 100644 include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp rename include/ck/tensor_operation/gpu/grid/{gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp => gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp} (58%) create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp create mode 100644 
library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp create mode 100644 library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp create mode 100644 test/gemm_ab_scale/CMakeLists.txt create mode 100644 test/gemm_ab_scale/test_gemm_ab_scale.cpp create mode 100644 test/gemm_ab_scale/test_gemm_ab_scale_util.hpp rename test/gemm_blockscale_wp/{test_gemm_blockscale_wp_xdl_fp8.cpp => test_gemm_blockscale_wp_fp8.cpp} (100%) diff --git a/example/65_gemm_multiply_multiply/CMakeLists.txt b/example/65_gemm_multiply_multiply/CMakeLists.txt index abfbe115fb..944a8f96bf 100644 --- a/example/65_gemm_multiply_multiply/CMakeLists.txt +++ b/example/65_gemm_multiply_multiply/CMakeLists.txt @@ -77,3 +77,5 @@ example_compile_options(example_moe_gemm1_xdl_fp8_blockscale PRIVATE ${BLOCKSCAL add_example_executable(example_gemm_add_add_wmma_fp16 gemm_add_add_wmma_fp16.cpp) add_example_executable(example_gemm_multiply_multiply_wmma_fp16_bpreshuffle gemm_multiply_multiply_wmma_fp16_bpreshuffle.cpp) add_example_executable(example_gemm_multiply_multiply_wmma_fp8_bpreshuffle gemm_multiply_multiply_wmma_fp8_bpreshuffle.cpp) +add_example_executable(example_gemm_multiply_multiply_wmma_fp8_ab_scale gemm_multiply_multiply_wmma_fp8_ab_scale.cpp) +add_example_executable(example_gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp) diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp new file mode 100644 index 0000000000..0fb7a70781 --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_ab_scale.cpp @@ -0,0 +1,345 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using B0Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3 + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, + 1, 1, S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v3, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8 || argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); + exit(0); + } + + ck::index_t Scale_Stride_AM = (K + Scale_Block_K - 1) / Scale_Block_K; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t 
col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + ck::Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + ck::Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AM, + A0Layout{})); + ck::Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + ck::Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + ck::Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + ck::Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 1}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); + a1_device_buf.ToDevice(a1_m_k.mData.data()); + b0_device_buf.ToDevice(b0_k_n.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + std::string 
op_name = device_op.GetTypeString(); + auto invoker = device_op.MakeInvoker(); + auto argument = + device_op.MakeArgument(static_cast(a0_device_buf.GetDeviceBuffer()), + static_cast(b0_device_buf.GetDeviceBuffer()), + std::array{}, + static_cast(e_device_buf.GetDeviceBuffer()), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + static_cast(a1_device_buf.GetDeviceBuffer()), + static_cast(b1_device_buf.GetDeviceBuffer()), + a_element_op, + b_element_op, + cde_element_op, + KBatch); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = .0; + + ave_time = invoker.Run(argument, StreamConfig{nullptr, false, 0, 50, 100}); + + int pass = 0; + + if(do_verification) + { + ck::Tensor c_m_n({M, N}); + ck::Tensor a_m_k({M, K}); + ck::Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + pass = ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 0 + : 1; + } + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_name << ", KBatch " << KBatch << std::endl; + + return pass; +} diff --git a/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp new file mode 100644 index 0000000000..ba95724d3f --- /dev/null +++ b/example/65_gemm_multiply_multiply/gemm_multiply_multiply_wmma_fp8_blockscale_bpreshuffle.cpp @@ -0,0 +1,357 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +#include "common.hpp" + +template +using S = ck::Sequence; + +using BF16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using A0DataType = FP8; +using A1DataType = F32; +using B0DataType = FP8; +using B1DataType = F32; +using AccDataType = F32; +using CShuffleDataType = F32; +using DsDataType = ck::Tuple<>; +using EDataType = BF16; + +using A0Layout = Row; +using A1Layout = Col; +using B0Layout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple<>; +using ELayout = Row; + +static constexpr int KPack = 16; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; +using CDEElementOp = PassThrough; + +static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default; + +static constexpr ck::index_t Scale_Block_M = 1; +static constexpr ck::index_t Scale_Block_N = 128; +static constexpr ck::index_t Scale_Block_K = 128; + +using DeviceOpInstance = + ck::tensor_operation::device::DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle + // clang-format off + , S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, + 2, 16, 16, 0, + 1, 1, + S<1, 32, 1, 8>, S<8>, + ck::BlockGemmPipelineScheduler::Intrawave, ck::BlockGemmPipelineVersion::v1, FP8>; +// clang-format on + +int main(int argc, char* argv[]) +{ + bool do_verification = true; + int init_method = 1; + bool time_kernel = false; + bool flush_cache = true; + + // GEMM shape + ck::index_t M = 128; + ck::index_t N = 1024; + ck::index_t K = 1024; + + ck::index_t StrideA = K; + ck::index_t StrideB = K; + ck::index_t StrideE = N; + + ck::index_t KBatch = 1; + + if(argc == 1) + { + // use default case + } + else if(argc == 4) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + } + else if(argc == 8 || argc == 9) + { + do_verification = std::stoi(argv[1]); + init_method = std::stoi(argv[2]); + time_kernel = std::stoi(argv[3]); + + M = std::stoi(argv[4]); + N = std::stoi(argv[5]); + K = std::stoi(argv[6]); + + flush_cache = std::stoi(argv[7]); + + if(argc == 9) + { + KBatch = std::stoi(argv[8]); + } + + StrideA = K; + StrideB = K; + StrideE = N; + } + else + { + printf("arg1: verification (0=no, 1=yes)\n"); + printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n"); + printf("arg3: time kernel (0=no, 1=yes)\n"); + printf("arg4 to 6: M, N, K\n"); + printf("arg7: flush both I$ and L2$ (0=no, 1=yes)\n"); + printf("arg8: KBatch (default: 1)\n"); + exit(0); + } + + // Transpose the AScale tensor for better 
performance + ck::index_t Scale_Stride_AK = (M + Scale_Block_M - 1) / Scale_Block_M; + ck::index_t Scale_Stride_BN = (K + Scale_Block_K - 1) / Scale_Block_K; + + auto f_host_tensor_descriptor = + [](std::size_t row, std::size_t col, std::size_t stride, auto layout) { + using namespace ck::literals; + + if(std::is_same::value) + { + return ck::HostTensorDescriptor({row, col}, {stride, 1_uz}); + } + else + { + return ck::HostTensorDescriptor({row, col}, {1_uz, stride}); + } + }; + + ck::Tensor a0_m_k(f_host_tensor_descriptor(M, K, StrideA, A0Layout{})); + ck::Tensor a1_m_k(f_host_tensor_descriptor((M + Scale_Block_M - 1) / Scale_Block_M, + (K + Scale_Block_K - 1) / Scale_Block_K, + Scale_Stride_AK, + A1Layout{})); + ck::Tensor b0_k_n(f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); + ck::Tensor b0_preshuffled( + f_host_tensor_descriptor(K, N, StrideB, B0Layout{})); // use layout only for size + ck::Tensor b1_k_n(f_host_tensor_descriptor((K + Scale_Block_K - 1) / Scale_Block_K, + (N + Scale_Block_N - 1) / Scale_Block_N, + Scale_Stride_BN, + B0Layout{})); + ck::Tensor e_m_n_host_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + ck::Tensor e_m_n_device_result(f_host_tensor_descriptor(M, N, StrideE, ELayout{})); + + std::cout << "a0_m_k: " << a0_m_k.mDesc << std::endl; + std::cout << "a1_m_k: " << a1_m_k.mDesc << std::endl; + std::cout << "b0_k_n: " << b0_k_n.mDesc << std::endl; + std::cout << "b1_k_n: " << b1_k_n.mDesc << std::endl; + std::cout << "e_m_n: " << e_m_n_host_result.mDesc << std::endl; + + switch(init_method) + { + case 0: break; + case 1: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + case 2: + a0_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b0_k_n.GenerateTensorValue(GeneratorTensor_1{}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 3: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 4: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_1{}); + break; + case 5: + a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); + a1_m_k.GenerateTensorValue(GeneratorTensor_1{}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + break; + default: + a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + b0_k_n.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); + a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + } + + ck::DeviceMem a0_device_buf(sizeof(A0DataType) * a0_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem a1_device_buf(sizeof(A1DataType) * a1_m_k.mDesc.GetElementSpaceSize()); + ck::DeviceMem b0_device_buf(sizeof(B0DataType) * b0_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem b1_device_buf(sizeof(B1DataType) * b1_k_n.mDesc.GetElementSpaceSize()); + ck::DeviceMem e_device_buf(sizeof(EDataType) * e_m_n_device_result.mDesc.GetElementSpaceSize()); + + a0_device_buf.ToDevice(a0_m_k.mData.data()); +
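// Illustrative note (hypothetical helper, not part of this patch): because
// A1Layout == Col, the A-scale element for block (m_blk, k_blk) sits at
// m_blk + k_blk * Scale_Stride_AK, i.e. all M-blocks of a given K-block are
// contiguous in memory. This is the transposed AScale layout the comment
// before the stride computation refers to.
// auto a_scale_offset = [&](ck::index_t m_blk, ck::index_t k_blk) {
//     return m_blk + k_blk * Scale_Stride_AK;
// };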
a1_device_buf.ToDevice(a1_m_k.mData.data()); + b1_device_buf.ToDevice(b1_k_n.mData.data()); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DsDataType::Size(); + + // do GEMM + auto device_op = DeviceOpInstance{}; + std::string op_name = device_op.GetTypeString(); + int NPerWmma = device_op.GetPreShuffleParameters(); + + preShuffleBuffer(b0_k_n.mData.data(), b0_preshuffled.mData.data(), N, K, NPerWmma); + + b0_device_buf.ToDevice(b0_preshuffled.mData.data()); + auto invoker = device_op.MakeInvoker(); + auto argument = device_op.MakeArgument(a0_device_buf.GetDeviceBuffer(), + b0_device_buf.GetDeviceBuffer(), + std::array{}, + e_device_buf.GetDeviceBuffer(), + M, + N, + K, + StrideA, + StrideB, + std::array{}, + StrideE, + a1_device_buf.GetDeviceBuffer(), + b1_device_buf.GetDeviceBuffer(), + a_element_op, + b_element_op, + cde_element_op, + KBatch); + + if(!device_op.IsSupportedArgument(argument)) + { + throw std::runtime_error( + "wrong! device_gemm with the specified compilation parameters does " + "not support this GEMM problem"); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_btype = + sizeof(A0DataType) * M * K + sizeof(B0DataType) * K * N + sizeof(EDataType) * M * N; + + float ave_time = 0.0f; + + if(flush_cache) + { + int rotating_buf = (512 * 1024 * 1024 + num_btype - 1) / num_btype; + + ave_time = invoker.Run(argument, + StreamConfig{nullptr, time_kernel, 0, 50, 100, true, rotating_buf}); + } + else + { + ave_time = invoker.Run(argument, StreamConfig{nullptr, time_kernel, 0, 50, 100}); + } + + float tflops = static_cast(flop) / 1.E9 / ave_time; + + float gb_per_sec = num_btype / 1.E6 / ave_time; + + std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " + << op_name << ", KBatch " << KBatch << std::endl; + + if(do_verification) + { + ck::Tensor c_m_n({M, N}); + ck::Tensor a_m_k({M, K}); + ck::Tensor b_k_n({K, N}); + + for(int m = 0; m < M; m++) + { + for(int k = 0; k < K; k++) + { + a_m_k(m, k) = ck::type_convert(a0_m_k(m, k)) * + a1_m_k(m / Scale_Block_M, k / Scale_Block_K); + } + } + + for(int n = 0; n < N; n++) + { + for(int k = 0; k < K; k++) + { + b_k_n(k, n) = ck::type_convert(b0_k_n(k, n)) * + b1_k_n(k / Scale_Block_K, n / Scale_Block_N); + } + } + + using ReferenceGemmInstance = ck::tensor_operation::host::ReferenceGemm; + auto ref_gemm = ReferenceGemmInstance{}; + auto ref_invoker = ref_gemm.MakeInvoker(); + + auto ref_argument = + ref_gemm.MakeArgument(a_m_k, b_k_n, c_m_n, PassThrough{}, PassThrough{}, PassThrough{}); + + ref_invoker.Run(ref_argument); + + for(int m = 0; m < M; ++m) + { + for(int n = 0; n < N; ++n) + { + e_m_n_host_result(m, n) = ck::type_convert(c_m_n(m, n)); + } + } + + e_device_buf.FromDevice(e_m_n_device_result.mData.data()); + + return ck::utils::check_err( + e_m_n_device_result, e_m_n_host_result, "Error: Incorrect results!", 5e-2, 5e-2) + ? 
0 + : 1; + } + + return 0; +} diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp index f24a1eb3bc..f831c0f6cf 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_base.hpp @@ -109,65 +109,145 @@ struct BlockwiseGemmWmmaops_pipeline_base } }; - template - struct BScale + typename ThreadDesc> + struct ABScale { - __device__ BScale(GridDesc b_scale_grid_desc_, - ThreadCopy b_scale_thread_copy_, - GridBuffer b_scale_grid_buf_) - : b_scale_thread_copy(b_scale_thread_copy_), - b_scale_grid_desc(b_scale_grid_desc_), - b_scale_grid_buf(b_scale_grid_buf_) {}; + __device__ ABScale(GridDesc scale_grid_desc_, + ThreadCopy scale_thread_copy_, + GridBuffer scale_grid_buf_) + : scale_thread_copy(scale_thread_copy_), + scale_grid_desc(scale_grid_desc_), + scale_grid_buf(scale_grid_buf_) {}; - static constexpr index_t num_scale_k_block = BScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr index_t num_scale_k_block = ThreadDesc{}.GetLength(Number<1>{}); static constexpr index_t num_scale_krepeat = KRepeat / num_scale_k_block; - static constexpr auto b_scale_thread_desc = BScaleThreadDesc{}; + static constexpr index_t num_slice_mn = ScaleSliceSizeMN; + static constexpr index_t num_slice_k = ScaleSliceSizeK; + static constexpr index_t reg_size_per_wmma = RegSizePerWmma; - static constexpr auto b_scale_thread_copy_step = - make_tuple(make_multi_index(NWaves * NPerWmma, 0), - make_multi_index(-NPerBlock, 0), - make_multi_index(-NPerBlock, (KPerBlock + ScaleBlockK - 1) / ScaleBlockK)); + static constexpr auto scale_thread_desc = ThreadDesc{}; + + static constexpr auto scale_thread_copy_step = + make_tuple(make_multi_index(ScaleSliceStrideMN, 0), + make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, 0), + make_multi_index(-ScaleSliceSizeMN / RegSizePerWmma * ScaleSliceStrideMN, + ScaleSliceSizeK)); template __device__ void GlobalLoad(bool cond) { - static_for<0, NRepeat, 1>{}([&](auto n0) { - b_scale_thread_copy.Run(b_scale_grid_desc, - b_scale_grid_buf, - b_scale_thread_desc, - make_tuple(n0, Number<0>{}), - b_scale_thread_bufs(Number{})); + static_for<0, ScaleSliceSizeMN / RegSizePerWmma, 1>{}([&](auto m0) { + scale_thread_copy.Run(scale_grid_desc, + scale_grid_buf, + scale_thread_desc, + make_tuple(m0 * Number{}, Number<0>{}), + scale_thread_bufs(Number{})); - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<0>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<0>{})); }); if(cond) { - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<2>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<2>{})); } else { - b_scale_thread_copy.MoveSrcSliceWindow(b_scale_grid_desc, - b_scale_thread_copy_step.At(Number<1>{})); + scale_thread_copy.MoveSrcSliceWindow(scale_grid_desc, + scale_thread_copy_step.At(Number<1>{})); } } - ThreadCopy b_scale_thread_copy; - GridDesc b_scale_grid_desc; - GridBuffer b_scale_grid_buf; - StaticallyIndexedArray{}> b_scale_thread_bufs; + ThreadCopy scale_thread_copy; + GridDesc scale_grid_desc; + GridBuffer scale_grid_buf; + StaticallyIndexedArray{}> scale_thread_bufs; + }; + + template + struct CScale + { + __device__ CScale() {} + + static constexpr 
auto reg_size_per_wmma = + ck::is_same_v && ck::is_same_v + ? 1 + : wmma_gemm.GetRegSizePerWmma(); + static constexpr auto c_scale_thread_desc = make_naive_tensor_descriptor_packed(make_tuple( + Number{}, + Number{}, + Number{})); + using CScaleThreadDesc = decltype(c_scale_thread_desc); + static constexpr auto num_scale_k_block = CScaleThreadDesc{}.GetLength(Number<0>{}); + static constexpr auto num_scale_m_block = CScaleThreadDesc{}.GetLength(Number<1>{}); + static constexpr auto num_scale_n_block = CScaleThreadDesc{}.GetLength(Number<2>{}); + using ThreadStaticBuffer = decltype(make_static_buffer( + c_scale_thread_desc.GetElementSpaceSize())); + + __device__ void Load(AScaleStruct& a_scale_struct, BScaleStruct& b_scale_struct) + { + using AScaleThreadDesc = decltype(AScaleStruct::scale_thread_desc); + using BScaleThreadDesc = decltype(BScaleStruct::scale_thread_desc); + + static_for<0, num_scale_m_block, 1>{}([&](auto m0) { + static_for<0, num_scale_n_block, 1>{}([&](auto n0) { + static_for<0, num_scale_k_block, 1>{}([&](auto k0) { + constexpr index_t c_offset = + CScaleThreadDesc{}.CalculateOffset(make_tuple(k0, m0, n0)); + constexpr index_t a_offset = + AScaleThreadDesc{}.CalculateOffset(make_tuple(m0, k0)); + constexpr index_t b_offset = + BScaleThreadDesc{}.CalculateOffset(make_tuple(n0, k0)); + + c_scale_thread_bufs(I0)(Number{}) = + a_scale_struct.scale_thread_bufs(I0)[Number{}] * + b_scale_struct.scale_thread_bufs(I0)[Number{}]; + }); + }); + }); + } + + __device__ void Clear() + { + static_for<0, reg_size_per_wmma, 1>{}([&](auto t) { + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()(Number{}) = 0; + }); + } + + template + __device__ void UpdateCThreadBuf(CThreadBuf& c_thread_buf) + { + static_for<0, reg_size_per_wmma, 1>{}([&](auto t) { + constexpr index_t c_offset = + c_thread_desc_.CalculateOffset(make_tuple(m_index, n_index, t)); + constexpr index_t cscale_offset = CScaleThreadDesc{}.CalculateOffset(make_tuple( + k_index, + (m_index * num_scale_m_block / MRepeat) % num_scale_m_block + + (Number{}) % + AScaleStruct::reg_size_per_wmma, + (n_index * num_scale_n_block / NRepeat) % num_scale_n_block)); + c_thread_buf(Number{}) += + c_thread_buf_per_scale.GetVectorTypeReference(Number<0>{}) + .template AsType()[Number{}] * + type_convert(c_scale_thread_bufs(I0)[Number{}]); + }); + } + + StaticallyIndexedArray{}> c_scale_thread_bufs; + StaticBufferTupleOfVector + c_thread_buf_per_scale; }; __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; } diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp index 0f62aee0a8..3b12e7feb0 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v1.hpp @@ -174,7 +174,9 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -188,7 +190,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); // Local prefill 1 @@ -217,6 +220,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}([&](auto k0) { static_for<0, MRepeat, 1>{}([&](auto m0) { a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, @@ -245,7 +249,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}], b_thread_desc_, @@ 
-366,6 +370,189 @@ struct BlockwiseGemmWmmaops_pipeline_v1 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + Base::a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + Base::b_thread_desc_.GetElementSpaceSize()); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Initialize C + c_thread_buf.Clear(); + + auto blockwise_gemm_func = [&]() { + // Local load + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + Base::a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + Base::a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + Base::b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + Base::b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); + }); + }); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + 
c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + }; + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + + block_sync_lds(); + blockwise_gemm_func(); + + block_sync_lds(); + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + i += 1; + } while(i < (num_loop - 1)); + } + + // tail + if constexpr(TailNum == TailNumber::Full) + { + block_sync_lds(); + blockwise_gemm_func(); + } + } + protected: // A[MRepeat, I1, I1, KPack] static constexpr auto a_thread_desc_ = make_naive_tensor_descriptor_packed( @@ -528,6 +715,23 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + struct KLoopParams + { + static constexpr auto KRepeatNoScale = 1; + static constexpr auto NumScaleKBlock = + Number{}; + static constexpr auto KRepeatPerNumScaleKBlock = KRepeatPerCluster / NumScaleKBlock; + }; + + template <> + struct KLoopParams + { + static constexpr index_t KRepeatNoScale = KRepeatPerCluster; + static constexpr index_t NumScaleKBlock = 1; + static constexpr index_t KRepeatPerNumScaleKBlock = 1; + }; + template + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -557,7 +763,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1(num_loop_per_scale == 1); // Local prefill 1 @@ -615,7 +822,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1{}], b_thread_desc_, @@ -704,6 +911,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -996,7 +1206,7 @@ struct BlockwiseGemmWmmaops_pipeline_v1 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc&, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer&, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + StaticallyIndexedArray{}> b_thread_bufs; + constexpr auto b_block_origin_idx = make_tuple(I0, I0, I0, I0, I0, I0, I0); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + 
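// The lambda below is the WMMA core for one register buffer: for each (m0, n0)
// repeat it walks the NumScaleKBlock scale windows along K, clears the per-scale
// accumulator (c_scale_struct.Clear()), issues KRepeat / NumScaleKBlock WMMA
// steps into it from the A thread buffer and the preshuffled B register buffer
// b_thread_bufs[reg_buf], and then UpdateCThreadBuf folds the combined
// a_scale * b_scale factor into the persistent C thread buffer.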
auto gemm_core_func = [&](auto reg_buf) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_bufs[reg_buf] + [Number{}, + I0, + I0, + n0, + I0, + k_index, + Number{}))>{}]; + }); + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + }; + + auto a_local_prefetch_func = [&]() { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, KRepeat, 1>{}([&](auto k0) { + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + }); + }; + + // Global prefetch A1 B1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(I0)); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + __builtin_amdgcn_sched_barrier(0); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill A1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + // Global prefetch A2 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, I0); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + // Local prefetch A1 + block_sync_lds(); + a_local_prefetch_func(); + + // Initialize C + c_thread_buf.Clear(); + + __builtin_amdgcn_sched_barrier(0); + + // main body + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + auto LoopFunc = [&](auto wmma_reg_buf, auto local_read_buf) { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(local_read_buf)); + + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf, wmma_reg_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf, local_read_buf); + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + + a_scale_struct.template GlobalLoad<0>( + (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>( + (i + 2 + wmma_reg_buf) % num_loop_per_scale == 0); + + gemm_core_func(wmma_reg_buf); + + block_sync_lds(); 
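// Ping-pong pipelining: LoopFunc runs twice per iteration with the register
// buffers swapped (I0/I1, then I1/I0), so the global read of the next B tile
// and the A-scale/B-scale fetches overlap the WMMA math on the current tile;
// the (i + 2 + wmma_reg_buf) % num_loop_per_scale guard keeps the two-deep
// scale prefetch aligned with the unroll-by-two loop count.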
+ + // loop prefetch copy + a_local_prefetch_func(); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + }; + + LoopFunc(I0, I1); + LoopFunc(I1, I0); + + i += 2; + } while(i < (num_loop - 2)); + } + + // tail + if constexpr(TailNum == TailNumber::Even) + { + b_blockwise_copy.Run(b_grid_desc, + b_grid_buf, + b_block_desc_k0_n0_n1_n2_k1, + b_block_origin_idx, + b_thread_bufs(I1)); + + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + + a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + + gemm_core_func(I0); + + block_sync_lds(); + + // tail Local Prefetch A1 + a_local_prefetch_func(); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + __builtin_amdgcn_sched_barrier(0); + + gemm_core_func(I1); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + else if constexpr(TailNum == TailNumber::Odd) + { + gemm_core_func(I0); + } + } + protected: static constexpr auto b_thread_desc_ = make_naive_tensor_descriptor_packed(make_tuple(Number{}, diff --git a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp index 08c765dd0a..b8d451363e 100644 --- a/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp +++ b/include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_wmmaops_v3.hpp @@ -123,6 +123,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3; using Base::I0; + using Base::I1; + using Base::I2; + using Base::I3; using Base::A_K1; using Base::A_KRow; @@ -322,7 +325,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3{}], b_thread_desc_, @@ -348,7 +351,9 @@ struct BlockwiseGemmWmmaops_pipeline_v3 + typename AScaleStruct, + typename BScaleStruct, + typename enable_if, bool>::type = false> __device__ void Run(const AGridDesc& a_grid_desc, const ABlockDesc& a_block_desc, ABlockTransfer& a_blockwise_copy, @@ -362,7 +367,7 @@ struct BlockwiseGemmWmmaops_pipeline_v3(num_loop_per_scale == 1); // Local prefill 1 @@ -611,6 +617,339 @@ struct BlockwiseGemmWmmaops_pipeline_v3 && + !ck::is_same_v, + bool>::type = false> + __device__ void Run(const AGridDesc& a_grid_desc, + const ABlockDesc& a_block_desc, + ABlockTransfer& a_blockwise_copy, + const AGridBuffer& a_grid_buf, + ABlockBuffer& a_block_buf, + const ABlockTransferStep& a_block_copy_step, + const BGridDesc& b_grid_desc, + const BBlockDesc& b_block_desc, + BBlockTransfer& b_blockwise_copy, + const BGridBuffer& b_grid_buf, + BBlockBuffer& b_block_buf, + const BBlockTransferStep& b_block_copy_step, + CThreadBuffer& c_thread_buf, + AScaleStruct& a_scale_struct, + BScaleStruct& b_scale_struct, + index_t num_loop, + index_t num_loop_per_scale) const + { + __builtin_amdgcn_sched_barrier(0); + + constexpr index_t KPerWaveBlock = wmma_gemm.GetKPerWaveBlk(); + static constexpr auto NumScaleKBlock = + Number{}; + + auto a_thread_buf = make_static_buffer( + a_thread_desc_.GetElementSpaceSize()); + auto b_thread_buf = make_static_buffer( + b_thread_desc_.GetElementSpaceSize()); + + using CScaleStruct = typename Base::template CScale; + auto c_scale_struct = CScaleStruct{}; + + // Global prefetch 1 + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, 
a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + // Scales global load + a_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + b_scale_struct.template GlobalLoad<0>(num_loop_per_scale == 1); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + + // Local prefill 1 + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // Global prefetch 2, perform when at least 2 loops exist. + if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full) + { + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + } + + // Initialize C + c_thread_buf.Clear(); + + // Local prefetch 1 + block_sync_lds(); + + auto local_load_func = [&]() { + static_for<0, KRepeat, 1>{}([&](auto k0) { + static_for<0, MRepeat, 1>{}([&](auto m0) { + a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_block_buf, + a_thread_desc_, + make_tuple(I0, m0, k0, I0, I0, I0, I0), + a_thread_buf); + }); + static_for<0, NRepeat, 1>{}([&](auto n0) { + b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_block_buf, + b_thread_desc_, + make_tuple(I0, n0, k0, I0, I0, I0, I0), + b_thread_buf); + }); + }); + }; + + local_load_func(); + + __builtin_amdgcn_sched_barrier(0); + + // Main body, perform when at least 3 loops exist. + if constexpr(HasMainLoop) + { + index_t i = 0; + do + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + a_blockwise_copy.RunRead(a_grid_desc, a_grid_buf); + b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf); + + a_blockwise_copy.MoveSrcSliceWindow(a_grid_desc, a_block_copy_step); + b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step); + + a_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>((i + 2) % num_loop_per_scale == 0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run(a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale + .GetVectorTypeReference(Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + + 
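// The combined scales for the next iteration are formed only after the
// current tile's math: CScale::Load multiplies the freshly fetched a_scale
// and b_scale thread buffers elementwise, so the next pass's UpdateCThreadBuf
// applies scale(A) * scale(B) to the per-scale partial sums before
// accumulating them into c_thread_buf.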
c_scale_struct.Load(a_scale_struct, b_scale_struct); + block_sync_lds(); + + local_load_func(); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + + i += 1; + } while(i < (num_loop - 2)); + } + + // Pre-tail, perform when at least 2 loops exist. + if constexpr(TailNum == TailNumber::Even || TailNum == TailNumber::Full) + { + block_sync_lds(); + + a_blockwise_copy.RunWrite(a_block_desc, a_block_buf); + b_blockwise_copy.RunWrite(b_block_desc, b_block_buf); + + // No RunRead or MoveSrcSliceWindow here, already finished them all! + a_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + b_scale_struct.template GlobalLoad<0>(num_loop % num_loop_per_scale == 0); + + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + static_for<0, KInner, 1>{}([&](auto k_inner) { + vector_type a_thread_vec; + vector_type b_thread_vec; + + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + + c_scale_struct.Load(a_scale_struct, b_scale_struct); + block_sync_lds(); + + local_load_func(); + + HotLoopScheduler(); + __builtin_amdgcn_sched_barrier(0); + } + + // Tail, always perform. 
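// The tail block below re-runs the same MRepeat x NRepeat x NumScaleKBlock
// WMMA core once more on data already resident in thread buffers; no further
// global loads, LDS writes, or scale fetches are issued at this point.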
+ { + static_for<0, MRepeat, 1>{}([&](auto m0) { + static_for<0, NRepeat, 1>{}([&](auto n0) { + static_for<0, NumScaleKBlock, 1>{}([&](auto kscale0) { + c_scale_struct.Clear(); + static_for<0, KRepeat / NumScaleKBlock, 1>{}([&](auto k0) { + vector_type a_thread_vec; + vector_type b_thread_vec; + static_for<0, KInner, 1>{}([&](auto k_inner) { + static_for<0, KPack / A_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + a_thread_vec.template AsType()(ik) = + a_thread_buf[Number{}, + m0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + static_for<0, KPack / B_KRow / KInner, 1>{}([&](auto ik) { + constexpr index_t kk = ik + k_inner * KPerWaveBlock; + constexpr index_t k_index = + kscale0 * (KRepeat / NumScaleKBlock) + k0; + b_thread_vec.template AsType()(ik) = + b_thread_buf[Number{}, + n0, + k_index, + I0, + I0, + I0, + Number{}))>{}]; + }); + + using wmma_input_type_a = + typename vector_type::type; + using wmma_input_type_b = + typename vector_type::type; + + wmma_gemm.Run( + a_thread_vec.template AsType(), + b_thread_vec.template AsType(), + c_scale_struct.c_thread_buf_per_scale.GetVectorTypeReference( + Number<0>{})); + }); + }); + c_scale_struct.template UpdateCThreadBuf(c_thread_buf); + }); + }); + }); + // Let's leak last WMMA block to epilogue region, cover the potential lds-shuffle + // latency + // __builtin_amdgcn_sched_barrier(0); + } + } + protected: using Base::a_thread_copy_; using Base::a_thread_desc_; diff --git a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp index 52a915de52..23b5178e3d 100644 --- a/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp @@ -105,6 +105,353 @@ struct DeviceGemmMultipleD_BlockScale_BPreshuffle : public BaseOperator virtual int GetPreShuffleParameters() = 0; }; +template +struct DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK : public BaseOperator +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + virtual std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op, + index_t KBatch) = 0; + + virtual std::unique_ptr MakeInvokerPointer() = 0; + + virtual int GetPreShuffleParameters() = 0; +}; + +/// @brief Wrapper for backward compatibility that allows to use instances of +/// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK in contexts where +// DeviceGemmMultipleD_BlockScale_BPreshuffle is expected. +/// +/// @note The main area where it can be used is DeviceOperationInstanceFactory::GetInstances(). +/// The only difference between API of DeviceGemmMultipleD_BlockScale_BPreshuffle and +// DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK is +/// that DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK::MakeArgumentPointer requires +// an additional parameter KBatch which is explicitly passed as 1 by this wrapper. 
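+///
+/// Minimal usage sketch (illustrative only; the concrete template arguments and the
+/// source of the SplitK instance are assumptions, typically it comes from the
+/// instance factory):
+///
+/// @code
+/// // std::unique_ptr<DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK<...>> p_splitk_op;
+/// auto p_op = std::make_unique<DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper<...>>(
+///     std::move(p_splitk_op));
+/// // The legacy 16-parameter MakeArgumentPointer now works; KBatch is pinned to 1.
+/// auto p_arg     = p_op->MakeArgumentPointer(p_a, p_b, p_ds, p_e, M, N, K, StrideA,
+///                                            StrideB, StrideDs, StrideE, p_a_scale,
+///                                            p_b_scale, a_op, b_op, cde_op);
+/// auto p_invoker = p_op->MakeInvokerPointer();
+/// @endcode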
+template
+struct DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper
+    : public DeviceGemmMultipleD_BlockScale_BPreshuffle
+{
+    using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK;
+
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+#ifndef __HIPCC_RTC__
+
+    explicit DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper(std::unique_ptr<DeviceOp> p_op)
+        : p_op_(std::move(p_op))
+    {
+    }
+
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return p_op_->IsSupportedArgument(p_arg);
+    }
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        const ck::index_t M,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t StrideA,
+                        const ck::index_t StrideB,
+                        const std::array<ck::index_t, NumDTensor> StrideDs,
+                        const ck::index_t StrideE,
+                        const void* p_a_scale,
+                        const void* p_b_scale,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op) override
+    {
+        return p_op_->MakeArgumentPointer(p_a,
+                                          p_b,
+                                          p_ds,
+                                          p_e,
+                                          M,
+                                          N,
+                                          K,
+                                          StrideA,
+                                          StrideB,
+                                          StrideDs,
+                                          StrideE,
+                                          p_a_scale,
+                                          p_b_scale,
+                                          a_element_op,
+                                          b_element_op,
+                                          cde_element_op,
+                                          1); // KBatch
+    }
+
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
+    {
+        return p_op_->MakeInvokerPointer();
+    }
+
+    int GetPreShuffleParameters() override { return p_op_->GetPreShuffleParameters(); }
+
+    std::string GetTypeString() const override { return p_op_->GetTypeString(); }
+
+    private:
+    std::unique_ptr<DeviceOp> p_op_;
+
+#endif // __HIPCC_RTC__
+};
+
+// GEMM:
+//   input : A[M, K], B[K, N],
+//   input : D0[M, N], D1[M, N], ...
+//   output : E[M, N]
+//   C = a_op(A) * b_op(B)
+//   E = cde_op(C, D0, D1, ...)
+// Assume:
+//   D0, D1, ... and E have the same layout
+template
+struct DeviceGemmMultipleD_ABScaleSplitK : public BaseOperator
+{
+    static constexpr index_t NumDTensor = DsDataType::Size();
+
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_e,
+                        const ck::index_t M,
+                        const ck::index_t N,
+                        const ck::index_t K,
+                        const ck::index_t StrideA,
+                        const ck::index_t StrideB,
+                        const std::array<ck::index_t, NumDTensor> StrideDs,
+                        const ck::index_t StrideE,
+                        const void* p_a_scale,
+                        const void* p_b_scale,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CDEElementwiseOperation cde_element_op,
+                        index_t KBatch) = 0;
+
+    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
+
+    virtual void SetKBatch(BaseArgument* arg, int KBatch) const = 0;
+};
+
+/// @brief Wrapper for backward compatibility that allows using instances of
+/// DeviceGemmMultipleD_ABScaleSplitK in contexts where DeviceGemmMultipleD_ABScale is
+/// expected.
+///
+/// @note The main place where it is useful is DeviceOperationInstanceFactory::GetInstances().
+/// The only difference between the APIs of DeviceGemmMultipleD_ABScale and
+/// DeviceGemmMultipleD_ABScaleSplitK is that
+/// DeviceGemmMultipleD_ABScaleSplitK::MakeArgumentPointer requires an additional parameter
+/// KBatch, which is explicitly passed as 1 by this wrapper.
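+///
+/// Minimal usage sketch of the SplitK side (illustrative only; the concrete instance
+/// type is assumed to come from the instance factory):
+///
+/// @code
+/// // p_op : std::unique_ptr<DeviceGemmMultipleD_ABScaleSplitK<...>>
+/// auto p_arg = p_op->MakeArgumentPointer(p_a, p_b, p_ds, p_e, M, N, K, StrideA, StrideB,
+///                                        StrideDs, StrideE, p_a_scale, p_b_scale,
+///                                        a_op, b_op, cde_op,
+///                                        /*KBatch=*/4);
+/// // Alternatively, retune the split factor on an existing argument; this also
+/// // recomputes the dependent KRead/KPadded/AK0/BK0 values:
+/// p_op->SetKBatch(p_arg.get(), 8);
+/// @endcode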
+template +struct DeviceGemmMultipleD_ABScaleSplitKWrapper + : public DeviceGemmMultipleD_ABScale +{ + + using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK; + + static constexpr index_t NumDTensor = DsDataType::Size(); + +#ifndef __HIPCC_RTC__ + + explicit DeviceGemmMultipleD_ABScaleSplitKWrapper(std::unique_ptr p_op) + : p_op_(std::move(p_op)) + { + } + + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return p_op_->IsSupportedArgument(p_arg); + } + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + const ck::index_t M, + const ck::index_t N, + const ck::index_t K, + const ck::index_t StrideA, + const ck::index_t StrideB, + const std::array StrideDs, + const ck::index_t StrideE, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CDEElementwiseOperation cde_element_op) override + { + return p_op_->MakeArgumentPointer(p_a, + p_b, + p_ds, + p_e, + M, + N, + K, + StrideA, + StrideB, + StrideDs, + StrideE, + p_a_scale, + p_b_scale, + a_element_op, + b_element_op, + cde_element_op, + 1); // KBatch + } + + void SetKBatch(BaseArgument* arg, int KBatch) const override { p_op_->SetKBatch(arg, KBatch); } + + std::unique_ptr MakeInvokerPointer() override + { + return p_op_->MakeInvokerPointer(); + } + + std::string GetTypeString() const override { return p_op_->GetTypeString(); } + + private: + std::unique_ptr p_op_; + +#endif // __HIPCC_RTC__ +}; + } // namespace device } // namespace tensor_operation } // namespace ck diff --git a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp index 7752b334ed..ee1ddc494d 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" @@ -93,7 +93,8 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) p_bs_grid_shift, karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset + c_batch_offset, - karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_k_split_offset, + karg.p_a_scale_grid, + karg.p_b_scale_grid + b_scale_batch_offset + splitk_batch_offset.scale_b_k_split_offset, p_shared, karg, karg.a_element_op, @@ -315,12 +316,13 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale }; // GridwiseGemm - using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_b_scale< + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< ALayout, BLayout, Tuple<>, // DsLayout CLayout, Tuple, + void, // AScaleType Tuple, BScaleDataType, AccDataType, @@ -332,6 +334,7 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale CElementwiseOperation, GemmSpec, BlockSize, + 0, // ScaleBlockM ScaleBlockN, ScaleBlockK, MPerBlock, @@ -405,7 +408,9 @@ struct DeviceBatchedGemm_Wmma_CShuffleV3_BScale std::array{StrideB_}, std::array{}, // 
StrideDs_ StrideC_, + 0, // StrideScaleA StrideScaleB_, + nullptr, p_b_scale_grid_, k_batch_, a_element_op_, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp new file mode 100644 index 0000000000..81a5d35e7c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp @@ -0,0 +1,362 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3 + : public DeviceGemmMultipleD_ABScaleSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< + ALayout, + BLayout, + DsLayout, + CLayout, + Tuple, + AScaleDataType, + Tuple, + BScaleDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB>; + + using Argument = typename GridwiseGemm::Argument; + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + DsDataType, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB>; + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + + return DeviceGemmCommon::IsSupportedArgument(arg); + 
+    }
+
+    // polymorphic
+    bool IsSupportedArgument(const BaseArgument* p_arg) override
+    {
+        return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
+    }
+
+    void SetKBatch(BaseArgument* base_arg, int KBatch) const override
+    {
+        auto& arg   = *dynamic_cast<Argument*>(base_arg);
+        arg.KBatch  = KBatch;
+        arg.KRead   = GridwiseGemm::CalculateKRead(arg.K, KBatch);
+        arg.KPadded = GridwiseGemm::CalculateKPadded(arg.K, KBatch);
+        arg.AK0     = GridwiseGemm::CalculateAK0Padded(arg.K, KBatch);
+        arg.BK0     = GridwiseGemm::CalculateBK0Padded(arg.K, KBatch);
+    }
+
+    static auto MakeArgument(const ADataType* p_a,
+                             const BDataType* p_b,
+                             std::array<const void*, NumDTensor> p_ds,
+                             CDataType* p_c,
+                             index_t M,
+                             index_t N,
+                             index_t K,
+                             index_t StrideA,
+                             index_t StrideB,
+                             const std::array<index_t, NumDTensor> StrideDs,
+                             index_t StrideC,
+                             const AScaleDataType* p_a_scale,
+                             const BScaleDataType* p_b_scale,
+                             AElementwiseOperation a_element_op,
+                             BElementwiseOperation b_element_op,
+                             CElementwiseOperation cde_element_op,
+                             index_t KBatch = 1)
+    {
+        // Default scale strides assume densely packed scale grids of
+        // ceil(M/ScaleBlockM) x ceil(K/ScaleBlockK) for A and
+        // ceil(N/ScaleBlockN) x ceil(K/ScaleBlockK) for B.
+        index_t StrideScaleA = ck::is_same_v
+                                   ? math::integer_divide_ceil(K, ScaleBlockK)
+                                   : math::integer_divide_ceil(M, ScaleBlockM);
+
+        index_t StrideScaleB = ck::is_same_v
+                                   ? math::integer_divide_ceil(K, ScaleBlockK)
+                                   : math::integer_divide_ceil(N, ScaleBlockN);
+
+        return Argument{std::array{p_a},
+                        std::array{p_b},
+                        p_ds,
+                        p_c,
+                        M,
+                        N,
+                        K,
+                        std::array{StrideA},
+                        std::array{StrideB},
+                        StrideDs,
+                        StrideC,
+                        StrideScaleA,
+                        StrideScaleB,
+                        p_a_scale,
+                        p_b_scale,
+                        KBatch,
+                        a_element_op,
+                        b_element_op,
+                        cde_element_op};
+    }
+
+    static auto MakeInvoker() { return Invoker{}; }
+
+    // polymorphic
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        std::array<const void*, NumDTensor> p_ds,
+                        void* p_c,
+                        index_t M,
+                        index_t N,
+                        index_t K,
+                        index_t StrideA,
+                        index_t StrideB,
+                        const std::array<index_t, NumDTensor> StrideDs,
+                        index_t StrideC,
+                        const void* p_a_scale,
+                        const void* p_b_scale,
+                        AElementwiseOperation a_element_op,
+                        BElementwiseOperation b_element_op,
+                        CElementwiseOperation c_element_op,
+                        index_t KBatch = 1) override
+    {
+        index_t StrideScaleA = ck::is_same_v
+                                   ? math::integer_divide_ceil(K, ScaleBlockK)
+                                   : math::integer_divide_ceil(M, ScaleBlockM);
+
+        index_t StrideScaleB = ck::is_same_v
+                                   ?
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return std::make_unique(std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_c), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemm_ABScale_Wmma_CShuffleV3" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< -__global__ void -#if CK_USE_LAUNCH_BOUNDS -__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) -#endif - kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) -{ -#if(defined(__gfx11__) || defined(__gfx12__)) -#if defined(__gfx11__) - // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions - using e_data_type = remove_cvref_t>; - if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && - (std::is_same_v || - std::is_same_v))) - { -#endif - constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< - typename GridwiseGemm::EpilogueCShuffle>(); - __shared__ char p_shared[LDS_size]; - - auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); - - const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); - const index_t k_id = blockIdx.z * num_k_per_block; - - auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; - - GridwiseGemm::template Run( - p_shared, splitk_batch_offset, karg, epilogue_args, k_id); - -#if defined(__gfx11__) - } -#endif -#else - ignore = karg; -#endif -} - -} // namespace ck - namespace ck { namespace tensor_operation { namespace device { @@ -202,270 +156,14 @@ struct DeviceGemmMultiD_Wmma_CShuffle_V3_BPreshuffle BlkGemmPipeSched, BlkGemmPipelineVer, ComputeTypeA, - ComputeTypeB>; + ComputeTypeB, + true>; // IsBPreshuffle // Invoker - struct Invoker : public BaseInvoker - { - /// @brief This function issues GPU kernel execution. - /// @param arg The GPU kernel arguments. - /// @param stream_config The HIP stream configuration helper structure. - /// @return The kernel's average execution time (if time measurement is - /// enabled). - float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) - { - if(stream_config.log_level_ > 0) - { - arg.Print(); - GridwiseGemm::BlockwiseGemmPipe::HotLoopInstList::Print(); - } - - if(!GridwiseGemm::CheckValidity(arg)) - { - throw std::runtime_error("wrong! 
GridwiseGemm has invalid setting"); - } - - index_t gdx, gdy, gdz; - std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(arg.M, arg.N, arg.KBatch); - - float ave_time = 0; - - index_t k_grain = arg.KBatch * KPerBlock; - index_t K_split = (arg.K + k_grain - 1) / k_grain * KPerBlock; - - const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K_split); - - const auto Run = [&](const auto& kernel) { - if(stream_config.flush_cache) - { - Argument arg_ = arg; - - const auto a_grid_desc_ak0_m_ak1 = GridwiseGemm::MakeAsGridDescriptor_AK0_M_AK1( - arg_.M, arg_.MPadded, arg_.K, arg_.KPadded, arg_.StrideAs, arg_.AK0); - const auto b_grid_desc_bk0_n_bk1 = GridwiseGemm::MakeBsGridDescriptor_BK0_N_BK1( - arg_.K, arg_.KPadded, arg_.N, arg_.NPadded, arg_.StrideBs, arg_.BK0); - - std::array size_as_buffers; - size_as_buffers[Number<0>{}] = - a_grid_desc_ak0_m_ak1[Number<0>{}].GetElementSpaceSize() * - sizeof(ADataType) / GridwiseGemm::APackedSize; - - std::array size_bs_buffers; - size_bs_buffers[Number<0>{}] = - b_grid_desc_bk0_n_bk1[Number<0>{}].GetElementSpaceSize() * - sizeof(BDataType) / GridwiseGemm::BPackedSize; - - const auto ds_grid_desc_m_n = GridwiseGemm::MakeDsGridDescriptor_M_N( - arg_.M, arg_.MPadded, arg_.N, arg_.NPadded, arg_.StrideDs); - - std::array size_ds_buffers; - static_for<0, GridwiseGemm::NumDTensor, 1>{}([&](auto i) { - using DDataType = remove_cvref_t>; - size_ds_buffers[i] = - ds_grid_desc_m_n[i].GetElementSpaceSize() * sizeof(DDataType); - }); - - ck::utility::RotatingMemWrapperMultiABD, - Tuple, - DsDataType> - rotating_mem(arg_, - stream_config.rotating_count, - size_as_buffers, - size_bs_buffers, - size_ds_buffers); - rotating_mem.Print(); - - auto run_flush_cache = [&]() { - // flush icache - ck::utility::flush_icache(); - // rotating mem - rotating_mem.Next(); - // clear c mem - if(arg_.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg_.p_e_grid, - 0, - arg_.M * arg_.N * sizeof(EDataType), - stream_config.stream_id_)); - }; - - ave_time = ck::utility::launch_and_time_kernel_with_preprocess( - stream_config, - run_flush_cache, - kernel, - dim3(gdx, gdy, gdz), - dim3(BlockSize), - 0, - arg_); - } - else - { - if(arg.KBatch > 1) - HIP_CHECK_ERROR(hipMemsetAsync(arg.p_e_grid, - 0, - arg.M * arg.N * sizeof(EDataType), - stream_config.stream_id_)); - - ave_time = launch_and_time_kernel( - stream_config, kernel, dim3(gdx, gdy, gdz), dim3(BlockSize), 0, arg); - } - }; - - constexpr index_t minimum_occupancy = []() { - if constexpr(BlkGemmPipeSched == BlockGemmPipelineScheduler::Interwave) - { - return 2; - } - else if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) - { - return (MPerBlock * NPerBlock / BlockSize <= 128) ? 2 : 1; - } - else - { - return 1; - } - }(); - - // ThreadwiseTensorSliceTransfer_v7r3 (used in ThreadGroupTensorSliceTransfer_v7r3) is - // currently implemented in such a way that all SrcScalarPerVectors must be the same, so - // if one of D matrices is column-major, then all SrcScalarPerVectors must be 1. On the - // other hand, Split K for 16-bit outputs uses packed atomics so ScalarPerVectors cannot - // be odd. 
- constexpr bool AtomicsImplementationExists = - !(std::is_same_v || std::is_same_v || - std::is_same_v) || - (CDEShuffleBlockTransferScalarPerVectors{}[0] % 2 == 0); - - if(has_main_k_block_loop) - { - // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - if constexpr(AtomicsImplementationExists) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - true, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) - { - if constexpr(AtomicsImplementationExists) - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::AtomicAdd, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - else - { - if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Odd>; - Run(kernel); - } - else - { - const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< - GridwiseGemm, - false, - InMemoryDataOperationEnum::Set, - minimum_occupancy, - TailNumber::Even>; - Run(kernel); - } - } - } - } - - return ave_time; - } - - // polymorphic - float Run(const BaseArgument* p_arg, - const StreamConfig& stream_config = StreamConfig{}) override - { - return Run(*dynamic_cast(p_arg), stream_config); - } - }; + using Invoker = typename DeviceGemmCommon::Invoker; static bool IsSupportedArgument(const Argument& arg) { - if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) - { - return false; - } return DeviceGemmCommon::IsSupportedArgument(arg); } diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp new file mode 100644 index 0000000000..1b1a1fcc6c --- /dev/null +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp @@ -0,0 +1,360 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include + +#include "ck/utility/common_header.hpp" +#include "ck/tensor_description/tensor_descriptor.hpp" +#include "ck/tensor_description/tensor_descriptor_helper.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d_ab_scale.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" +#include "ck/host_utility/device_prop.hpp" +#include "ck/host_utility/kernel_launch.hpp" +#include "ck/host_utility/flush_cache.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { + +template +struct DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle + : public DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK +{ + static constexpr index_t NumDTensor = DsDataType::Size(); + + using AScaleLayout = tensor_layout::gemm::ColumnMajor; + using BScaleLayout = BLayout; + + // GridwiseGemm + using GridwiseGemm = GridwiseGemm_wmma_cshuffle_v3_ab_scale< + ALayout, + BLayout, + DsLayout, + CLayout, + Tuple, + AScaleDataType, + Tuple, + BScaleDataType, + AccDataType, + CShuffleDataType, + DsDataType, + CDataType, + AElementwiseOperation, + BElementwiseOperation, + CElementwiseOperation, + GemmSpec, + BlockSize, + ScaleBlockM, + ScaleBlockN, + ScaleBlockK, + MPerBlock, + NPerBlock, + KPerBlock, + AK1, + BK1, + MPerWmma, + NPerWmma, + MRepeat, + NRepeat, + ABlockTransferThreadClusterLengths_AK0_M_AK1, + ABlockTransferThreadClusterArrangeOrder, + ABlockTransferSrcAccessOrder, + ABlockTransferSrcVectorDim, + ABlockTransferSrcScalarPerVector, + ABlockTransferDstScalarPerVector_AK1, + false, + ABlockLdsExtraM, + BBlockTransferThreadClusterLengths_BK0_N_BK1, + BBlockTransferThreadClusterArrangeOrder, + BBlockTransferSrcAccessOrder, + BBlockTransferSrcVectorDim, + BBlockTransferSrcScalarPerVector, + BBlockTransferDstScalarPerVector_BK1, + false, + BBlockLdsExtraN, + CShuffleMRepeatPerShuffle, + CShuffleNRepeatPerShuffle, + CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + PermuteA, + PermuteB, + true, + AScaleLayout, + BScaleLayout>; + + using Argument = typename GridwiseGemm::Argument; + int GetPreShuffleParameters() override { return NPerWmma; } + + using DeviceGemmCommon = + DeviceGemm_Wmma_CShuffleV3_Common, + Tuple, + DsDataType, + CDataType, + MPerBlock, + NPerBlock, + KPerBlock, + BlockSize, + AK1, + BK1, + GemmSpec, + CShuffleBlockTransferScalarPerVectors, + BlkGemmPipeSched, + BlkGemmPipelineVer, + ComputeTypeA, + ComputeTypeB, + true>; // IsBPreshuffle + + // Invoker + using Invoker = typename DeviceGemmCommon::Invoker; + + static bool IsSupportedArgument(const Argument& arg) + { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + + return DeviceGemmCommon::IsSupportedArgument(arg); + } + + // polymorphic + bool IsSupportedArgument(const BaseArgument* p_arg) override + { + return IsSupportedArgument(*dynamic_cast(p_arg)); + } + + static auto MakeArgument(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + 
const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation cde_element_op, + index_t KBatch) + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return Argument{std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_e), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + cde_element_op}; + } + + static auto MakeInvoker() { return Invoker{}; } + + // polymorphic + std::unique_ptr + MakeArgumentPointer(const void* p_a, + const void* p_b, + std::array p_ds, + void* p_e, + index_t M, + index_t N, + index_t K, + index_t StrideA, + index_t StrideB, + const std::array StrideDs, + index_t StrideC, + const void* p_a_scale, + const void* p_b_scale, + AElementwiseOperation a_element_op, + BElementwiseOperation b_element_op, + CElementwiseOperation c_element_op, + index_t KBatch) override + { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + + return std::make_unique(std::array{p_a}, + std::array{p_b}, + p_ds, + static_cast(p_e), + M, + N, + K, + std::array{StrideA}, + std::array{StrideB}, + StrideDs, + StrideC, + StrideScaleA, + StrideScaleB, + static_cast(p_a_scale), + static_cast(p_b_scale), + KBatch, + a_element_op, + b_element_op, + c_element_op); + } + + // polymorphic + std::unique_ptr MakeInvokerPointer() override + { + return std::make_unique(Invoker{}); + } + + // polymorphic + std::string GetTypeString() const override + { + auto str = std::stringstream(); + + std::map BlkGemmPipelineSchedulerToString{ + {BlockGemmPipelineScheduler::Intrawave, "Intrawave"}, + {BlockGemmPipelineScheduler::Interwave, "Interwave"}}; + + std::map BlkGemmPipelineVersionToString{ + {BlockGemmPipelineVersion::v1, "v1"}, + {BlockGemmPipelineVersion::v2, "v2"}, + {BlockGemmPipelineVersion::v3, "v3"}, + {BlockGemmPipelineVersion::v4, "v4"}, + {BlockGemmPipelineVersion::v5, "v5"}}; + + // clang-format off + str << "DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle" + << "<" + << getGemmSpecializationString(GemmSpec) << ", " + << std::string(ALayout::name)[0] + << std::string(BLayout::name)[0] + << std::string(CLayout::name)[0] + << ">" + << " BlkSize: " + << BlockSize << ", " + << "BlkTile: " + << MPerBlock<<"x"< 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else { const auto kernel = kernel_gemm_xdl_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } else if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) { - const auto kernel = - kernel_gemm_xdl_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } + else + { + const auto kernel = + kernel_gemm_xdl_cshuffle_v3; + Run(kernel); + } } } } @@ -315,6 
+350,20 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 { auto& arg = *dynamic_cast(base_arg); arg.KBatch = KBatch; + if(get_warp_size() == 64) + { + arg.KRead = GridwiseGemm64::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm64::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm64::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm64::CalculateBK0Padded(arg.K, KBatch); + } + else + { + arg.KRead = GridwiseGemm32::CalculateKRead(arg.K, KBatch); + arg.KPadded = GridwiseGemm32::CalculateKPadded(arg.K, KBatch); + arg.AK0 = GridwiseGemm32::CalculateAK0Padded(arg.K, KBatch); + arg.BK0 = GridwiseGemm32::CalculateBK0Padded(arg.K, KBatch); + } } static constexpr bool IsValidCompilationParameter() @@ -325,6 +374,13 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 static bool IsSupportedArgument(const Argument& arg) { + // with splitk the implementation doesn't work + // when KRead % ScaleBlockK != 0, independently of K padding + if(arg.KBatch > 1 && arg.KRead % ScaleBlockK != 0) + { + return false; + } + if(!ck::is_xdl_wmma_supported()) { return false; @@ -385,6 +441,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + return Argument{static_cast(p_a), static_cast(p_b), p_ds, @@ -396,6 +460,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 StrideB, StrideDs, StrideC, + StrideScaleA, + StrideScaleB, static_cast(p_a_scale), static_cast(p_b_scale), 1, @@ -425,6 +491,14 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 BElementwiseOperation b_element_op, CElementwiseOperation c_element_op) override { + index_t StrideScaleA = ck::is_same_v + ? math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(M, ScaleBlockM); + + index_t StrideScaleB = ck::is_same_v + ? 
math::integer_divide_ceil(K, ScaleBlockK) + : math::integer_divide_ceil(N, ScaleBlockN); + return std::make_unique(static_cast(p_a), static_cast(p_b), p_ds, @@ -436,6 +510,8 @@ struct DeviceGemmMultiD_ABScale_Xdl_CShuffle_V3 StrideB, StrideDs, StrideC, + StrideScaleA, + StrideScaleB, static_cast(p_a_scale), static_cast(p_b_scale), 1, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp index e824fcc9dd..491f3a7dac 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_b_scale.hpp @@ -12,7 +12,7 @@ #include "ck/tensor_operation/gpu/device/tensor_layout.hpp" #include "ck/tensor_operation/gpu/device/device_gemm_v2.hpp" #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" -#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp" +#include "ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp" #include "ck/host_utility/device_prop.hpp" #include "ck/host_utility/kernel_launch.hpp" #include "ck/host_utility/flush_cache.hpp" @@ -86,12 +86,13 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale, // DsLayout CLayout, Tuple, + void, // AScaleType Tuple, BScaleDataType, AccDataType, @@ -103,6 +104,7 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{StrideB}, std::array{}, // StrideDs_ StrideC, + 0, // StrideScaleA StrideScaleB, + nullptr, p_b_scale, KBatch, a_element_op, @@ -245,7 +249,9 @@ struct DeviceGemm_BScale_Wmma_CShuffleV3 : public DeviceGemmV2BScale{StrideB}, std::array{}, // StrideDs_ StrideC, + 0, // StrideScaleA StrideScaleB, + nullptr, // p_a_scale static_cast(p_b_scale), KBatch, a_element_op, diff --git a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp index 6706365fb7..e96ec58cba 100644 --- a/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma_cshuffle_v3_common.hpp @@ -38,7 +38,8 @@ template + typename ComputeTypeB, + bool IsBPreShuffled = false> struct DeviceGemm_Wmma_CShuffleV3_Common { @@ -189,61 +190,174 @@ struct DeviceGemm_Wmma_CShuffleV3_Common if(has_main_k_block_loop) { // Tail number always full - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || - BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) + if constexpr(IsBPreShuffled) { - if(arg.KBatch > 1) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - if constexpr(AtomicsImplementationExists) + if(arg.KBatch > 1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if constexpr(AtomicsImplementationExists) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = 
kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } } - } - else - { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); } } else { - // TODO: Implement - } - } - else - { - // Tail number always 1 - if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) - { - if(arg.KBatch > 1) + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1 || + BlkGemmPipelineVer == BlockGemmPipelineVersion::v3) { - if constexpr(AtomicsImplementationExists) + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + const auto kernel = kernel_gemm_wmma_cshuffle_v3< + GridwiseGemm, + true, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else { const auto kernel = kernel_gemm_wmma_cshuffle_v3; Run(kernel); } } - else + } + } + else + { + if constexpr(IsBPreShuffled) + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) { - const auto kernel = - kernel_gemm_wmma_cshuffle_v3; - Run(kernel); + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == + TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + else + { + if(GridwiseGemm::CalculateKBlockLoopTailNum(K_split) == TailNumber::Odd) + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Odd>; + Run(kernel); + } + else + { + const auto kernel = kernel_gemm_b_preshuffle_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::Set, + minimum_occupancy, + TailNumber::Even>; + Run(kernel); + } + } + } + } + else + { + if constexpr(BlkGemmPipelineVer == BlockGemmPipelineVersion::v1) + { + if(arg.KBatch > 1) + { + if constexpr(AtomicsImplementationExists) + { + const auto kernel = kernel_gemm_wmma_cshuffle_v3< + GridwiseGemm, + false, + InMemoryDataOperationEnum::AtomicAdd, + minimum_occupancy>; + Run(kernel); + } + } + else + { + const auto kernel = + kernel_gemm_wmma_cshuffle_v3; + Run(kernel); + } } } } @@ -299,6 +413,14 @@ struct DeviceGemm_Wmma_CShuffleV3_Common return false; } + if constexpr(IsBPreShuffled) + { + if(arg.N % NPerBlock != 0 || arg.K % KPerBlock != 0) + { + return false; + } + } + return GridwiseGemm::CheckValidity(arg); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp index 4526eb3186..69f8f44390 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_ab_transfer_thread_tiles.hpp @@ -388,11 +388,11 @@ struct ABTransferThreadTiles // 1 - MNRepeat - K0 / KRow - MNWaves - KRow - MNPerWmma - K1 return transform_tensor_descriptor( BlockDesc{}, - make_tuple( - 
make_unmerge_transform(make_tuple(Number{}, KRow, Number<1>{})), - make_unmerge_transform( - make_tuple(Number{}, Number{}, Number{})), - make_pass_through_transform(Number{})), + make_tuple(make_unmerge_transform(make_tuple( + Number{}, KRow, Number{})), + make_unmerge_transform(make_tuple( + Number{}, Number{}, Number{})), + make_pass_through_transform(Number{})), make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}), make_tuple(Sequence<2, 4, 0>{}, Sequence<1, 3, 5>{}, Sequence<6>{})); } diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp index f58f67dc6b..121ca258be 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_gemm_wmma_cshuffle_v3.hpp @@ -895,8 +895,9 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 c_thread_buf.Clear(); // Empty BScale struct for the blockwise pipeline. - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; + using ABScale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = ABScale{}; + auto b_scale_struct = ABScale{}; /*******************************************************************************/ // @@ -919,6 +920,7 @@ struct GridwiseBatchedGemmGemm_wmma_cshuffle_v3 b0_block_buf, b0_block_slice_copy_step, acc0_thread_buf, + a_scale_struct, b_scale_struct, KBlockMainLoop, 1); // num_k_block_per_scale diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp index e55ac807c5..fea0102337 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3.hpp @@ -618,8 +618,9 @@ struct GridwiseGemm_wmma_cshuffle_v3 __builtin_amdgcn_readfirstlane(block_work_idx[Number{}]); // BScale struct (Empty) - using BScale = typename BlockwiseGemmPipe::Empty; - auto b_scale_struct = BScale{}; + using Scale = typename BlockwiseGemmPipe::Empty; + auto a_scale_struct = Scale{}; + auto b_scale_struct = Scale{}; const index_t num_k_block_per_scale = GetKBlockPerScale(); @@ -627,6 +628,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 decltype(bs_grid_desc_bk0_n_bk1), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(a_scale_struct), decltype(b_scale_struct), decltype(epilogue_args), HasMainKBlockLoop, @@ -646,6 +648,7 @@ struct GridwiseGemm_wmma_cshuffle_v3 block_m_id, block_n_id, num_k_block_per_scale, + a_scale_struct, b_scale_struct, epilogue_args, k_id); diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp similarity index 58% rename from include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp rename to include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp index 8684731c96..ac5b7dd0c4 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_b_scale.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_ab_scale.hpp @@ -23,6 +23,7 @@ template -struct GridwiseGemm_wmma_cshuffle_v3_b_scale + BlockGemmPipelineScheduler BlkGemmPipeSched, + BlockGemmPipelineVersion BlkGemmPipelineVer, + typename ComputeTypeA, + typename 
ComputeTypeB, + bool PermuteA, + bool PermuteB, + bool IsBPreShuffled = false, + typename AScaleLayout = ALayout, + typename BScaleLayout = BLayout> +struct GridwiseGemm_wmma_cshuffle_v3_ab_scale : GridwiseGemm_wmma_cshuffle_v3_base< ALayout, BLayout, @@ -123,7 +128,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale ComputeTypeB, PermuteA, PermuteB, - false, + IsBPreShuffled, true> { using Base = GridwiseGemm_wmma_cshuffle_v3_base< @@ -177,7 +182,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale ComputeTypeB, PermuteA, PermuteB, - false, + IsBPreShuffled, true>; using Base::I0; @@ -233,6 +238,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs_, std::array StrideDs_, index_t StrideE_, + index_t StrideScaleA_, index_t StrideScaleB_, index_t KBatch_) : M{M_}, @@ -242,6 +248,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale StrideBs{StrideBs_}, StrideDs{StrideDs_}, StrideE{StrideE_}, + StrideScaleA{StrideScaleA_}, StrideScaleB{StrideScaleB_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, @@ -251,7 +258,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale AK0{CalculateAK0Padded(K_, KBatch_)}, BK0{CalculateBK0Padded(K_, KBatch_)}, MBlock{CalculateMBlock(M_)}, - NBlock{CalculateNBlock(N_)} + NBlock{CalculateNBlock(N_)}, + Kt{K_} { } @@ -275,11 +283,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale }); std::cout << " }, "; } - std::cout << "SE:" << StrideE << ", " << "SScaleB:" << StrideScaleB << ", " - << "MP:" << MPadded << ", " << "NP:" << NPadded << ", " << "KRead:" << KRead - << ", " << "KP:" << KPadded << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 - << ", " << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" - << std::endl; + std::cout << "SE:" << StrideE << ", " << "SScaleA:" << StrideScaleA << ", " + << "SScaleB:" << StrideScaleB << ", " << "MP:" << MPadded << ", " + << "NP:" << NPadded << ", " << "KRead:" << KRead << ", " << "KP:" << KPadded + << ", " << "AK0:" << AK0 << ", " << "BK0:" << BK0 << ", " + << "MBlock: " << MBlock << ", " << "NBlock: " << NBlock << "}" << std::endl; } index_t M; @@ -289,6 +297,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs; std::array StrideDs; index_t StrideE; + index_t StrideScaleA; index_t StrideScaleB; index_t KBatch; index_t MPadded; @@ -299,6 +308,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale index_t BK0; index_t MBlock; index_t NBlock; + index_t Kt; }; // Argument @@ -315,7 +325,9 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array StrideBs_, std::array StrideDs_, index_t StrideE_, + index_t StrideScaleA_, index_t StrideScaleB_, + const AScaleType* p_a_scale_grid_, const BScaleType* p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, @@ -329,12 +341,14 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale StrideBs_, StrideDs_, StrideE_, + StrideScaleA_, StrideScaleB_, k_batch_}, p_as_grid{}, p_bs_grid{}, p_ds_grid{}, p_e_grid{p_e_grid_}, + p_a_scale_grid{p_a_scale_grid_}, p_b_scale_grid{p_b_scale_grid_}, a_element_op{a_element_op_}, b_element_op{b_element_op_}, @@ -379,6 +393,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale DsGridPointer p_ds_grid; EDataType* p_e_grid; + const AScaleType* p_a_scale_grid; const BScaleType* p_b_scale_grid; const AElementwiseOperation a_element_op; const BElementwiseOperation b_element_op; @@ -407,34 +422,52 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale [&](auto i) { a_k_split_offset[i] = k_id * karg.KRead * karg.StrideAs[i]; }); } - if constexpr(is_same_v) + if constexpr(IsBPreShuffled) { - static_for<0, NumBTensor, 1>{}( - [&](auto i) { 
b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { b_k_split_offset[i] = 0; }); } - else if constexpr(is_same_v) + else { - if constexpr(!PermuteB) + if constexpr(is_same_v) { - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); + static_for<0, NumBTensor, 1>{}([&](auto i) { + b_k_split_offset[i] = k_id * karg.KRead * karg.StrideBs[i]; + }); } - else + else if constexpr(is_same_v) { - const int k0_offset = karg.KRead * karg.N; - static_for<0, NumBTensor, 1>{}( - [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); + if constexpr(!PermuteB) + { + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * karg.KRead / BPackedSize; }); + } + else + { + const int k0_offset = karg.KRead * karg.N; + static_for<0, NumBTensor, 1>{}( + [&](auto i) { b_k_split_offset[i] = k_id * k0_offset / BPackedSize; }); + } } } - // Calculate B scale offset - if constexpr(is_same_v) + // Calculate A scale offset + if constexpr(is_same_v) { - scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideB; + scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK); } - else if constexpr(is_same_v) + else if constexpr(is_same_v) { - scale_k_split_offset = k_id * (karg.KRead / ScaleBlockK); + scale_a_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + scale_b_k_split_offset = k_id * (karg.KRead / ScaleBlockK); } if(k_id < karg.KBatch - 1) @@ -458,77 +491,225 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale std::array a_k_split_offset; std::array b_k_split_offset; - index_t scale_k_split_offset; // New member for scale matrix offset + index_t scale_a_k_split_offset; // A scale matrix offset + index_t scale_b_k_split_offset; // B scale matrix offset index_t c_reduce_offset; }; using BlockwiseGemmPipe = typename Base::BlockwiseGemmPipe; // return block_id to C matrix tile idx (m0, n0) mapping - // if arch = gfx942 using Block2CTileMap = BlockToCTileMap_Grouped_M00_N0_M01Adapt<8, MPerBlock, NPerBlock>; - // using Block2CTileMap = BlockToCTileMap_3DGrid_KSplit; - template - __device__ static auto MakeBScale(const BScaleGridDesc_BN_AK& b_scale_grid_desc_bn_ak, - const BScaleType* p_b_scale_grid, - index_t block_n_id) + __device__ static constexpr auto + MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA) { - const auto b_scale_grid_buf = make_dynamic_buffer( - p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + const auto BM = math::integer_divide_ceil(M, ScaleBlockM); + const auto BK = math::integer_divide_ceil(K, ScaleBlockK); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA)); + } + } - static constexpr auto wmma = - WmmaSelector{}; - static constexpr auto KPerThread = wmma.selected_wmma.k_per_wmma; + template + __device__ static auto + MakeAScale(const Problem& problem, const AScaleType* p_a_scale_grid, index_t block_m_id) + { + if constexpr(ck::is_same_v) + { + using AScale = typename BlockwiseGemmPipe::Empty; + return AScale{}; + } + else + { +#if defined(__gfx11__) + // TODO: remove this restriction + 
static_assert(ScaleBlockM >= MPerWmma, + "ScaleBlockM must be greater equal than MPerWmma"); +#endif + static_assert( + ScaleBlockK >= + WmmaSelector:: + selected_wmma.k_per_wmma, + "ScaleBlockK must be greater equal than KPerWmma"); - static constexpr auto ScaleSliceSizeN = NRepeat; - static constexpr auto ScaleSliceSizeK = (KPerThread + ScaleBlockK - 1) / ScaleBlockK; + const auto a_scale_grid_desc_am_ak = + MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA); - constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( - make_tuple(Number{}, Number{})); + const auto a_scale_grid_buf = make_dynamic_buffer( + p_a_scale_grid, a_scale_grid_desc_am_ak.GetElementSpaceSize()); - constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + constexpr auto wmma = + WmmaSelector{}; + constexpr auto RegSizePerWmmaFull = + wmma.selected_wmma.num_acc_vgprs_per_wave * wmma.selected_wmma.acc_pack_number; + constexpr auto RegSizePerWmma = + math::integer_divide_ceil(RegSizePerWmmaFull, ScaleBlockM); - auto b_thread_offset_n = get_thread_local_1d_id() % NPerWmma + - (get_thread_local_1d_id() / 32) % NWaves * NPerWmma; - auto b_thread_offset_k = (get_thread_local_1d_id() % 32) / NPerWmma * KPerThread; + constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWmma); + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); - auto b_scale_thread_copy = - ThreadwiseTensorSliceTransfer_v2, - Sequence<0, 1>, - 1, - ScaleSliceSizeK, - 1, - false>( - b_scale_grid_desc_bn_ak, - make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, - b_thread_offset_k / ScaleBlockK)); + constexpr auto ScaleSliceSizeM = + ScaleBlockM < MPerWmma ? MRepeat * RegSizePerWmma + : math::integer_divide_ceil(MPerBlock, ScaleBlockM); + constexpr auto ScaleSliceStrideM = + math::integer_divide_ceil(MWaves * MPerWmma, ScaleBlockM); + constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); - auto b_scale_thread_buf = make_static_buffer( - b_scale_thread_desc.GetElementSpaceSize()); + constexpr auto a_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); - using BScale = - typename BlockwiseGemmPipe::template BScale; + auto a_thread_offset_m = + ((get_thread_local_1d_id() % 32) / MPerWmma * RegSizePerWmma) / + math::integer_divide_ceil(ScaleBlockM, RegSizePerWmmaFull) + + (get_thread_local_1d_id() / 32) / NWaves * MPerWmma / ScaleBlockM; - return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + constexpr index_t VectorDim = + is_same::value ? 0 : 1; + constexpr index_t VectorSize = + is_same::value ? 
RegSizePerWmma + : ScaleSliceSizeK; + + auto a_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + VectorDim, + VectorSize, + 1, + true>( + a_scale_grid_desc_am_ak, + make_multi_index(block_m_id * MPerBlock / ScaleBlockM + a_thread_offset_m, 0)); + + auto a_scale_thread_buf = make_static_buffer( + a_scale_thread_desc.GetElementSpaceSize()); + + using AScale = + typename BlockwiseGemmPipe::template ABScale; + + return AScale{a_scale_grid_desc_am_ak, a_scale_thread_copy, a_scale_grid_buf}; + } + } + + __device__ static constexpr auto + MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB) + { + const auto BN = math::integer_divide_ceil(N, ScaleBlockN); + const auto BK = math::integer_divide_ceil(K, ScaleBlockK); + if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1)); + } + else if constexpr(is_same::value) + { + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB)); + } + } + + template + __device__ static auto + MakeBScale(const Problem& problem, const BScaleType* p_b_scale_grid, index_t block_n_id) + { + if constexpr(ck::is_same_v) + { + using BScale = typename BlockwiseGemmPipe::Empty; + return BScale{}; + } + else + { + static_assert( + ScaleBlockK >= + WmmaSelector:: + selected_wmma.k_per_wmma, + "ScaleBlockK must be greater equal than KPerWmma"); + + const auto b_scale_grid_desc_bn_ak = + MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB); + + const auto b_scale_grid_buf = make_dynamic_buffer( + p_b_scale_grid, b_scale_grid_desc_bn_ak.GetElementSpaceSize()); + + constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWmma); + + constexpr auto ScaleSliceSizeN = + ScaleBlockN < NPerWmma ? NRepeat + : math::integer_divide_ceil(NPerBlock, ScaleBlockN); + constexpr auto ScaleSliceStrideN = + math::integer_divide_ceil(NWaves * NPerWmma, ScaleBlockN); + constexpr auto ScaleSliceSizeK = math::integer_divide_ceil(KPerBlock, ScaleBlockK); + + constexpr auto b_scale_thread_desc = make_naive_tensor_descriptor_packed( + make_tuple(Number{}, Number{})); + + auto b_thread_offset_n = (get_thread_local_1d_id() % NPerWmma + + (get_thread_local_1d_id() / 32) % NWaves * NPerWmma) / + ScaleBlockN; + + constexpr index_t VectorDim = + is_same::value ? 0 : 1; + constexpr index_t VectorSize = + is_same::value ? 
1 : ScaleSliceSizeK; + + auto b_scale_thread_copy = + ThreadwiseTensorSliceTransfer_v2, + Sequence<0, 1>, + VectorDim, + VectorSize, + 1, + true>( + b_scale_grid_desc_bn_ak, + make_multi_index(block_n_id * NPerBlock / ScaleBlockN + b_thread_offset_n, 0)); + + auto b_scale_thread_buf = make_static_buffer( + b_scale_thread_desc.GetElementSpaceSize()); + + using BScale = + typename BlockwiseGemmPipe::template ABScale; + + return BScale{b_scale_grid_desc_bn_ak, b_scale_thread_copy, b_scale_grid_buf}; + } } __device__ static index_t GetKBlockPerScale() { - return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + if constexpr(ck::is_same_v && ck::is_same_v) + { + return 0; + } + else + { + return (ScaleBlockK + KPerBlock - 1) / KPerBlock; + } } template ( @@ -562,12 +746,6 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( e_grid_desc_m_n, problem.MBlock, problem.NBlock); - // B Scale grid - const auto b_scale_grid_desc_bn_ak = make_naive_tensor_descriptor( - make_tuple(math::integer_divide_ceil(problem.N, ScaleBlockN), - math::integer_divide_ceil(problem.K, ScaleBlockK)), - make_tuple(problem.StrideScaleB, 1)); - // divide block work by [M, N] const auto block_2_ctile_map = Block2CTileMap{problem.M, problem.N, 4}; @@ -585,8 +763,11 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale const index_t block_m_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]); const index_t block_n_id = __builtin_amdgcn_readfirstlane(block_work_idx[I1]); + // AScale struct + auto a_scale_struct = MakeAScale<1>(problem, p_a_scale_grid, block_m_id); + // BScale struct - auto b_scale_struct = MakeBScale<1>(b_scale_grid_desc_bn_ak, p_b_scale_grid, block_n_id); + auto b_scale_struct = MakeBScale<1>(problem, p_b_scale_grid, block_n_id); const index_t num_k_block_per_scale = GetKBlockPerScale(); @@ -594,6 +775,7 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale decltype(bs_grid_desc_bk0_n_bk1), decltype(ds_grid_desc_mblock_mperblock_nblock_nperblock), decltype(e_grid_desc_mblock_mperblock_nblock_nperblock), + decltype(a_scale_struct), decltype(b_scale_struct), decltype(epilogue_args), HasMainKBlockLoop, @@ -613,8 +795,10 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale block_m_id, block_n_id, num_k_block_per_scale, + a_scale_struct, b_scale_struct, - epilogue_args); + epilogue_args, + k_id); } // NOTE: Wrapper function to have __global__ function in common @@ -626,7 +810,8 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale __device__ static void Run(void* p_shared, const SplitKBatchOffset& splitk_batch_offset, Argument& karg, - EpilogueArgument& epilogue_args) + EpilogueArgument& epilogue_args, + const index_t k_id = 0) { // shift A matrices pointer for splitk AsGridPointer p_as_grid_splitk; @@ -644,18 +829,40 @@ struct GridwiseGemm_wmma_cshuffle_v3_b_scale splitk_batch_offset.b_k_split_offset[i]; }); + const AScaleType* p_a_scale_grid_ptr; + if constexpr(ck::is_same_v) + { + p_a_scale_grid_ptr = karg.p_a_scale_grid; + } + else + { + p_a_scale_grid_ptr = karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset; + } + + const BScaleType* p_b_scale_grid_ptr; + if constexpr(ck::is_same_v) + { + p_b_scale_grid_ptr = karg.p_b_scale_grid; + } + else + { + p_b_scale_grid_ptr = karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset; + } + Run( p_as_grid_splitk, p_bs_grid_splitk, karg.p_ds_grid, karg.p_e_grid + splitk_batch_offset.c_reduce_offset, - karg.p_b_scale_grid + splitk_batch_offset.scale_k_split_offset, + p_a_scale_grid_ptr, + p_b_scale_grid_ptr, p_shared, karg, 
karg.a_element_op, karg.b_element_op, karg.cde_element_op, - epilogue_args); + epilogue_args, + k_id); } }; diff --git a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp index 04d1d98448..81aa1ac986 100644 --- a/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp +++ b/include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_cshuffle_v3_common.hpp @@ -69,6 +69,48 @@ __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) #endif } +template +__global__ void +#if CK_USE_LAUNCH_BOUNDS +__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, MinimumOccupancy) +#endif + kernel_gemm_b_preshuffle_wmma_cshuffle_v3(typename GridwiseGemm::Argument karg) +{ +#if(defined(__gfx11__) || defined(__gfx12__)) +#if defined(__gfx11__) + // gfx11 does not support *_atomic_pk_add_f16/bf16 instructions + using e_data_type = remove_cvref_t>; + if constexpr(!(EGlobalMemoryDataOperation == InMemoryDataOperationEnum::AtomicAdd && + (std::is_same_v || + std::is_same_v))) + { +#endif + constexpr index_t LDS_size = GridwiseGemm::template GetSharedMemoryNumberOfByte< + typename GridwiseGemm::EpilogueCShuffle>(); + __shared__ char p_shared[LDS_size]; + + auto splitk_batch_offset = typename GridwiseGemm::SplitKBatchOffset(karg, blockIdx.z); + + const index_t num_k_per_block = math::integer_divide_ceil(karg.K, GridwiseGemm::KPack); + const index_t k_id = blockIdx.z * num_k_per_block; + + auto epilogue_args = typename GridwiseGemm::EpilogueCShuffle{}; + + GridwiseGemm::template Run( + p_shared, splitk_batch_offset, karg, epilogue_args, k_id); + +#if defined(__gfx11__) + } +#endif +#else + ignore = karg; +#endif +} + template ( - karg.p_a_grid, - karg.p_b_grid, + karg.p_a_grid + splitk_batch_offset.a_k_split_offset, + karg.p_b_grid + splitk_batch_offset.b_k_split_offset, karg.p_ds_grid, karg.p_c_grid, - karg.p_a_scale_grid, - karg.p_b_scale_grid, + karg.p_a_scale_grid + splitk_batch_offset.scale_a_k_split_offset, + karg.p_b_scale_grid + splitk_batch_offset.scale_b_k_split_offset, p_shared, karg, karg.a_element_op, @@ -405,31 +407,33 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 } } - __host__ __device__ static constexpr auto MakeAScaleGridDesciptor_M_K(index_t M, index_t K) + __host__ __device__ static constexpr auto + MakeAScaleGridDesciptor_M_K(index_t M, index_t K, index_t StrideScaleA) { const auto BM = math::integer_divide_ceil(M, ScaleBlockM); const auto BK = math::integer_divide_ceil(K, ScaleBlockK); if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(BK, I1)); + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(StrideScaleA, I1)); } else if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, BM)); + return make_naive_tensor_descriptor(make_tuple(BM, BK), make_tuple(I1, StrideScaleA)); } } - __host__ __device__ static constexpr auto MakeBScaleGridDesciptor_N_K(index_t N, index_t K) + __host__ __device__ static constexpr auto + MakeBScaleGridDesciptor_N_K(index_t N, index_t K, index_t StrideScaleB) { const auto BN = math::integer_divide_ceil(N, ScaleBlockN); const auto BK = math::integer_divide_ceil(K, ScaleBlockK); if constexpr(is_same::value) { - return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(BK, I1)); + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(StrideScaleB, I1)); } else if constexpr(is_same::value) { - return 
make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, BN)); + return make_naive_tensor_descriptor(make_tuple(BN, BK), make_tuple(I1, StrideScaleB)); } } @@ -548,6 +552,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB_, std::array StrideDs_, index_t StrideC_, + index_t StrideScaleA_, + index_t StrideScaleB_, index_t KBatch_) : M{M_}, N{N_}, @@ -556,6 +562,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 StrideB{StrideB_}, StrideDs{StrideDs_}, StrideC{StrideC_}, + StrideScaleA{StrideScaleA_}, + StrideScaleB{StrideScaleB_}, KBatch{KBatch_}, MPadded{CalculateMPadded(M_)}, NPadded{CalculateNPadded(N_)}, @@ -585,7 +593,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB; std::array StrideDs; index_t StrideC; - + index_t StrideScaleA; + index_t StrideScaleB; index_t KBatch; index_t MPadded; index_t NPadded; @@ -611,13 +620,24 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t StrideB_, std::array StrideDs_, index_t StrideC_, + index_t StrideScaleA_, + index_t StrideScaleB_, const AScaleType* p_a_scale_grid_, const BScaleType* p_b_scale_grid_, index_t k_batch_, AElementwiseOperation a_element_op_, BElementwiseOperation b_element_op_, CElementwiseOperation c_element_op_) - : Problem{M_, N_, K_, StrideA_, StrideB_, StrideDs_, StrideC_, k_batch_}, + : Problem{M_, + N_, + K_, + StrideA_, + StrideB_, + StrideDs_, + StrideC_, + StrideScaleA_, + StrideScaleB_, + k_batch_}, p_a_grid{p_a_grid_}, p_b_grid{p_b_grid_}, p_ds_grid{}, @@ -673,6 +693,28 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 b_k_split_offset = blockIdx.z * karg.KRead; } + // Calculate A scale offset + if constexpr(is_same_v) + { + scale_a_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + else if constexpr(is_same_v) + { + scale_a_k_split_offset = + blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleA; + } + + // Calculate B scale offset + if constexpr(is_same_v) + { + scale_b_k_split_offset = + blockIdx.z * (karg.KRead / ScaleBlockK) * karg.StrideScaleB; + } + else if constexpr(is_same_v) + { + scale_b_k_split_offset = blockIdx.z * (karg.KRead / ScaleBlockK); + } + if(blockIdx.z < static_cast(karg.KBatch - 1)) { karg.K = karg.KRead; @@ -685,6 +727,8 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 index_t a_k_split_offset; index_t b_k_split_offset; + index_t scale_a_k_split_offset; // A scale matrix offset + index_t scale_b_k_split_offset; // B scale matrix offset }; __device__ static constexpr auto GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1() @@ -1221,8 +1265,10 @@ struct GridwiseGemmMultiD_ABScale_xdl_cshuffle_v3 const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N( problem.M, problem.MPadded, problem.N, problem.NPadded, problem.StrideC); - const auto a_scale_grid_desc_am_ak = MakeAScaleGridDesciptor_M_K(problem.M, problem.K); - const auto b_scale_grid_desc_bn_ak = MakeBScaleGridDesciptor_N_K(problem.N, problem.K); + const auto a_scale_grid_desc_am_ak = + MakeAScaleGridDesciptor_M_K(problem.M, problem.K, problem.StrideScaleA); + const auto b_scale_grid_desc_bn_ak = + MakeBScaleGridDesciptor_N_K(problem.N, problem.K, problem.StrideScaleB); const auto c_grid_desc_mblock_mperblock_nblock_nperblock = MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock( diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp index faf10c2cce..d4ddbafeee 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp +++ 
b/library/include/ck/library/tensor_operation_instance/gpu/gemm_ab_scale.hpp @@ -16,7 +16,231 @@ namespace ck { namespace tensor_operation { namespace device { namespace instance { + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_WMMA_FP8 +// Row, Col +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +// Row, Row +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +// Col, Row +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances); +#endif +#ifdef CK_USE_XDL // Row, Col void add_device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector>>& instances); #endif +#endif template -struct DeviceOperationInstanceFactory, - CLayout, - A0DataType, - A1DataType, - B0DataType, - B1DataType, - Tuple<>, - CDataType, 
- 1, - 128, - 128, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough, - ck::tensor_operation::element_wise::PassThrough>> +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_ABScaleSplitK, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>> +{ + using DeviceOp = DeviceGemmMultipleD_ABScaleSplitK, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_XDL + // No XDL instances for DeviceGemmMultipleABDSplitK with Add at the moment +#endif +#ifdef CK_USE_WMMA_FP8 + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + else if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + op_ptrs); + + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + op_ptrs); + add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + op_ptrs); + } + } +#endif +#endif + return op_ptrs; + } +}; + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_ABScale, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>> { using DeviceOp = DeviceGemmMultipleD_ABScale; + PassThrough, + PassThrough, + PassThrough>; static auto GetInstances() { std::vector> op_ptrs; #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#ifdef CK_USE_XDL if constexpr(is_same_v && is_same_v && is_same_v) { @@ -328,6 +655,33 @@ struct DeviceOperationInstanceFactory, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>; + + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA_FP8 #endif return op_ptrs; } diff --git a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp 
index a8d9545194..d660c18fd0 100644 --- a/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp +++ b/library/include/ck/library/tensor_operation_instance/gpu/gemm_blockscale_wp.hpp @@ -17,6 +17,47 @@ namespace tensor_operation { namespace device { namespace instance { #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) +#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8)) +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances); +#endif // CK_USE_WMMA && CK_USE_WMMA_FP8 + +#ifdef CK_USE_XDL void add_device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( std::vector>>& instances); #endif +#endif + +template +struct DeviceOperationInstanceFactory< + ck::tensor_operation::device::DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>> +{ + using DeviceOp = DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + + static auto GetInstances() + { + std::vector> op_ptrs; + +#if defined(CK_USE_XDL) + // No XDL instances for DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK at the moment +#endif // CK_USE_XDL + +#if(defined(CK_USE_WMMA) && defined(CK_USE_WMMA_FP8)) +#if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) + if constexpr(is_same_v && is_same_v && + is_same_v) + { + if constexpr(is_same_v && is_same_v && + is_same_v) + { + add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + op_ptrs); + + add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + op_ptrs); + } + } +#endif +#endif // CK_USE_WMMA && CK_USE_WMMA_FP8 + + return op_ptrs; + } +}; template > op_ptrs; +#ifdef CK_USE_XDL #if(defined(CK_ENABLE_BF16) || defined(CK_ENABLE_FP8)) if constexpr(is_same_v && is_same_v && is_same_v) @@ -162,6 +280,35 @@ struct DeviceOperationInstanceFactory< } } #endif +#endif + +#if defined(CK_USE_WMMA) + // Reuse DeviceGemmMultipleD_BlockScale_BPreshuffleSplitK instances + using Wrapper = DeviceGemmMultipleD_BlockScale_BPreshuffleWrapper< + ALayout, + BLayout, + Tuple<>, + CLayout, + A0DataType, + A1DataType, + B0DataType, + B1DataType, + Tuple<>, + CDataType, + 1, + 128, + 128, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough, + ck::tensor_operation::element_wise::PassThrough>; + auto new_op_ptrs = + DeviceOperationInstanceFactory::GetInstances(); + for(auto& op_ptr : new_op_ptrs) + { + op_ptrs.emplace_back(std::make_unique(std::move(op_ptr))); + } +#endif // CK_USE_WMMA + return op_ptrs; } }; diff --git 
a/library/src/tensor_operation_instance/gpu/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/CMakeLists.txt index ef037526ca..575e14d5bb 100644 --- a/library/src/tensor_operation_instance/gpu/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/CMakeLists.txt @@ -103,6 +103,16 @@ function(add_instance_library INSTANCE_NAME) message(DEBUG "removing gemm_universal_preshuffle_f8 instance ${source} ") list(REMOVE_ITEM ARGN "${source}") endif() + # Do not build gemm_ab_scale_f8 for any targets except gfx94, gfx95 and gfx12 + if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_ab_scale") AND (source_name MATCHES "_f8_f8_")) + message(DEBUG "removing gemm_ab_scale_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() + # Do not build gemm_blockscale_wp_f8 for any targets except gfx94, gfx95 and gfx12 + if(NOT (INST_TARGETS MATCHES "gfx942" OR INST_TARGETS MATCHES "gfx950" OR INST_TARGETS MATCHES "gfx12") AND (source_name MATCHES "gemm_blockscale_wp") AND (source_name MATCHES "_f8_f8_")) + message(DEBUG "removing gemm_blockscale_wp_f8 instance ${source} ") + list(REMOVE_ITEM ARGN "${source}") + endif() # Only build tf32 instances for gfx942 & gfx950 if(source_name MATCHES "_tf32_") if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32)) @@ -300,7 +310,7 @@ FOREACH(subdir_path ${dir_list}) message(DEBUG "Found gemm_multiply_multiply instances, but gfx94/gfx95/gfx11/gfx12 not on the target list. Skipping. ${cmake_instance}") set(add_inst 0) endif() - if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle|gemm_blockscale" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) + if(("${cmake_instance}" MATCHES "gemm_universal_preshuffle|gemm_blockscale|gemm_ab_scale" AND "${cmake_instance}" MATCHES "_f8_" ) AND (NOT INST_TARGETS MATCHES "gfx94|gfx95|gfx12") AND (NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH)) message(DEBUG "Found gemm_f8 instances, but gfx94/gfx95 not on the target list. Skipping.") set(add_inst 0) endif() diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt index a315db8bdd..0512b01175 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/CMakeLists.txt @@ -1,21 +1,38 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
# SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS set(GEMM_AB_SCALE_INSTANCES) list(APPEND GEMM_AB_SCALE_INSTANCES # Row, Col + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + # Row, Row + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + # Col, Row + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp + device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp + device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -27,11 +44,13 @@ set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_s set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") 
set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + # Row, Row set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") + # Col, Row set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") set_source_files_properties(device_gemm_ab_scale_xdl_f8_f8_bf16/device_gemm_ab_scale_xdl_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp PROPERTIES COMPILE_OPTIONS ";-mllvm;-greedy-reverse-local-assignment=1") diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp new file mode 100644 index 0000000000..a4058ca1c2 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp @@ -0,0 +1,79 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| 
SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 0, S< 8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + 
DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Col, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, S< 8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..ad0667dd10 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..dbdfd41e32 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..1380df5291 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..90dbb9c9d5 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_km_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp new file mode 100644 index 0000000000..c45adb91c6 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp @@ -0,0 +1,80 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template +using S = Sequence; + +using PassThrough = element_wise::PassThrough; +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| 
Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 64, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 
256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8> + // clang-format on + >; + +template +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + // Memory friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, 
BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Row, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..766279520a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..b837c35810 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..2fc87ba6ad --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..2188a64c98 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_kn_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp new file mode 100644 index 0000000000..cc1be58946 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -0,0 +1,95 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_ab_scale.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template <index_t... Is> +using S = Sequence<Is...>; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template <GemmSpecialization GemmSpec> +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Compute friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave,
BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 64, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 4>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 64, 16, 16, 16, 16, 4, 2, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 64, 16, 16, 16, 16, 2, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, 
PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 64, 16, 16, 16, 16, 4, 2, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 64, 16, 16, 16, 16, 2, 1, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +template <GemmSpecialization GemmSpec, BlockGemmPipelineScheduler BlkGemmPipeSched> +using device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //#######################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle | CShuffle | CBlockTransferClusterLengths | CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //#######################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | |Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat | NRepeat | _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVector| Pipeline| Pipeline| + //#######################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle | PerShuffle | | | Scheduler| Version| + //#######################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + + // Memory friendly + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 2, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, +
DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 256, 8, 16, 16, 16, 1, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<4>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 1, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 2, 4, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, 
BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 2, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..3c140ef980 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp new file mode 100644 index 0000000000..d68b755506 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp @@ -0,0 +1,37 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp new file mode 100644 index 0000000000..5822fd0b2a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp new file mode 100644 index 0000000000..f4661891d1 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_ab_scale/device_gemm_ab_scale_wmma_f8_f8_bf16/device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_v1_kpadding_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& instances) +{ + add_device_operation_instances( + instances, + device_gemm_ab_scale_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt index b37a22d895..dd7596447e 100644 --- a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright (c) Advanced Micro Devices, Inc., or its affiliates. # SPDX-License-Identifier: MIT -# ONLY XDL_KERNELS +# ONLY XDL_AND_WMMA_KERNELS if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") set(GEMM_BLOCKSCALE_WP_INSTANCES) @@ -10,6 +10,9 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx9[45]|gfx12") device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_comp_kpadding_instance.cpp device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_default_instance.cpp device_gemm_blockscale_wp_xdl_f8_f8_bf16/device_gemm_blockscale_wp_xdl_f8_f8_bf16_mk_nk_mn_128_128_128_mem_v1_kpadding_instance.cpp + + device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp + device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp ) check_cxx_compiler_flag("-mllvm --misched-bottomup=1" HAS_MISCHED_BOTTOMUP) check_cxx_compiler_flag("-mllvm --misched-prera-direction=bottomup" HAS_MISCHED_PRERA_DIRECTION) diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp new file mode 100644 index 0000000000..023d1ac2b8 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp @@ -0,0 +1,77 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle_v3_blockscale_bpreshuffle.hpp" + +#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +using F8 = f8_t; +using BF16 = bhalf_t; +using F32 = float; + +using Row = tensor_layout::gemm::RowMajor; +using Col = tensor_layout::gemm::ColumnMajor; + +template <index_t... Is> +using S = Sequence<Is...>; + +using PassThrough = element_wise::PassThrough; + +static constexpr auto GemmDefault = GemmSpecialization::Default; +static constexpr auto GemmKPadding = GemmSpecialization::KPadding; +static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding; +static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding; + +static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave; + +template <GemmSpecialization GemmSpec> +using device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances = std::tuple< + // clang-format off + //######################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //######################################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| | | Scheduler| Version| + //######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 128, 128, 128, 16, 16, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>,
BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; + +template <GemmSpecialization GemmSpec> +using device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances = std::tuple< + // clang-format off + //######################################################| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| Scale| Scale| Scale| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CDEBlockTransferClusterLengths| CShuffleBlockTransfer| Block-wiseGemm| Block-wiseGemm| + //######################################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline| + //######################################################| | | | | | | | | | | Operation| Operation| Operation| | | M| N| K| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| | | Scheduler| Version| + //######################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 256, 128, 8, 16, 16, 16, 1, 2, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 128, 8, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 128, 8, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 0, S< 8, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, +
DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 16, 128, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 128, 1, 128, 128, 16, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 8, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 256, 128, 16, 16, 16, 16, 2, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 128, 16, 16, 16, 16, 1, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 128, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 32, 64, 256, 16, 16, 16, 16, 1, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 256, 128, 16, 16, 16, 16, 4, 2, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, 
Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 128, 16, 16, 16, 16, 4, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 128, 16, 16, 16, 16, 2, 1, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S< 8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 128, 256, 16, 16, 16, 16, 4, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 16, 1, 16>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8>, + DeviceGemmMultiD_BlockScale_Wmma_CShuffle_V3_BPreshuffle< Row, Col, Tuple<>, Row, F8, F32, F8, F32, Tuple<>, BF16, F32, F32, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 1, 128, 128, 64, 64, 256, 16, 16, 16, 16, 2, 1, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1, F8> + // clang-format on + >; +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp new file mode 100644 index 0000000000..59fe63421a --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_comp_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. 
+// SPDX-License-Identifier: MIT + +#include "device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_comp_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp new file mode 100644 index 0000000000..2b5670ead3 --- /dev/null +++ b/library/src/tensor_operation_instance/gpu/gemm_blockscale_wp/device_gemm_blockscale_wp_wmma_f8_f8_bf16/device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128_mem_default_instance.cpp @@ -0,0 +1,38 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_128_128_128.hpp" + +namespace ck { +namespace tensor_operation { +namespace device { +namespace instance { + +void add_device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_default_instances( + std::vector, + Row, + F8, + F32, + F8, + F32, + Tuple<>, + BF16, + 1, + 128, + 128, + PassThrough, + PassThrough, + PassThrough>>>& + instances) +{ + add_device_operation_instances( + instances, + device_gemm_blockscale_wp_wmma_f8_f8_bf16_mk_nk_mn_1_128_128_mem_instances{}); +} + +} // namespace instance +} // namespace device +} // namespace tensor_operation +} // namespace ck diff --git a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp index 5396a52e21..f3055575ea 100644 --- a/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp +++ b/profiler/include/profiler/profile_gemm_ab_scale_impl.hpp @@ -109,8 +109,8 @@ bool profile_gemm_ab_scale_impl(int do_verification, case 1: a0_m_k.GenerateTensorValue(GeneratorTensor_2{-2, 2}); b0_k_n.GenerateTensorValue(GeneratorTensor_2{-2, 2}); - a1_m_k.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); - b1_k_n.GenerateTensorValue(GeneratorTensor_3{0, 1.0}); + a1_m_k.GenerateTensorValue(GeneratorTensor_2{-1, 2}); + b1_k_n.GenerateTensorValue(GeneratorTensor_2{-1, 2}); break; default: a0_m_k.GenerateTensorValue(GeneratorTensor_3{-0.5, 0.5}); @@ -302,7 +302,7 @@ bool profile_gemm_ab_scale_impl(int do_verification, float gb_per_sec = num_btype / 1.E6 / ave_time; std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << tflops << " TFlops, " - << gb_per_sec << " GB/s, " << op_name << std::endl; + << gb_per_sec << " GB/s, " << op_name << ", KBatch " << KBatch << std::endl; if(tflops > best_tflops) { diff --git a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp index 49fef5a0fc..8642cc59e6 100644 --- a/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp +++ 
b/profiler/include/profiler/profile_gemm_blockscale_wp_impl.hpp @@ -29,7 +29,7 @@ void preShuffleBuffer(const InOutDataType* src, InOutDataType* dst, int N, int K { int KPack = 16; int NLane = NXdl; - int KLane = 64 / NLane; + int KLane = ck::get_warp_size() / NLane; int K0 = K / (KLane * KPack); // K -> K0 KLane KPack diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index b7db14945d..802f29024c 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -261,6 +261,7 @@ add_subdirectory(gemm_multiply_multiply_wp) add_subdirectory(gemm_split_k) add_subdirectory(gemm_universal) add_subdirectory(gemm_universal_preshuffle) +add_subdirectory(gemm_ab_scale) add_subdirectory(gemm_b_scale) add_subdirectory(gemm_universal_streamk) add_subdirectory(gemm_reduce) diff --git a/test/gemm_ab_scale/CMakeLists.txt b/test/gemm_ab_scale/CMakeLists.txt new file mode 100644 index 0000000000..21203aafaa --- /dev/null +++ b/test/gemm_ab_scale/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +# SPDX-License-Identifier: MIT + +if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") + add_gtest_executable(test_gemm_ab_scale test_gemm_ab_scale.cpp) + if(result EQUAL 0) + target_link_libraries(test_gemm_ab_scale PRIVATE utility device_gemm_ab_scale_instance) + endif() +endif() diff --git a/test/gemm_ab_scale/test_gemm_ab_scale.cpp b/test/gemm_ab_scale/test_gemm_ab_scale.cpp new file mode 100644 index 0000000000..01c3e2ffdb --- /dev/null +++ b/test/gemm_ab_scale/test_gemm_ab_scale.cpp @@ -0,0 +1,236 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include + +#include "gtest/gtest.h" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "test_gemm_ab_scale_util.hpp" + +using BF16 = ck::bhalf_t; +using F32 = float; +using F8 = ck::f8_t; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +namespace { + +template <typename, typename> +struct tuple_concat; + +template <typename... Ts, typename... Us> +struct tuple_concat<std::tuple<Ts...>, std::tuple<Us...>> +{ + using type = std::tuple<Ts..., Us...>; +}; + +} // namespace + +template <typename Tuple> +class TestGemmABScale_MK_NK : public ck::test::TestGemmABScale< + typename tuple_concat<std::tuple<Row, Col, Row>, Tuple>::type> +{ +}; + +template <typename Tuple> +class TestGemmABScale_MK_KN : public ck::test::TestGemmABScale< + typename tuple_concat<std::tuple<Row, Row, Row>, Tuple>::type> +{ +}; + +template <typename Tuple> +class TestGemmABScale_KM_KN : public ck::test::TestGemmABScale< + typename tuple_concat<std::tuple<Col, Row, Row>, Tuple>::type> +{ +}; + +// clang-format off +using KernelTypes = ::testing::Types< + // A0DataType, A1DataType, B0DataType, B1DataType, ComputeDataType, EDataType + std::tuple< F8, F32, F8, F32, F8, BF16> + >; +// clang-format on + +TYPED_TEST_SUITE(TestGemmABScale_MK_NK, KernelTypes); +TYPED_TEST_SUITE(TestGemmABScale_MK_KN, KernelTypes); +TYPED_TEST_SUITE(TestGemmABScale_KM_KN, KernelTypes); + +// Row Col +TYPED_TEST(TestGemmABScale_MK_NK, SmallM) +{ + std::vector<int> Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, SmallMPadK) +{ + std::vector<int> Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 704; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, MidLargeM) +{ + std::vector<int> Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; +
constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_NK, Regular) +{ + std::vector<int> Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = K; + constexpr int StrideE = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideE); +} + +// Row Row +TYPED_TEST(TestGemmABScale_MK_KN, SmallM) +{ + std::vector<int> Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, SmallMPadK) +{ + std::vector<int> Ms{1, 2, 3, 4, 5, 6}; + constexpr int N = 512; + constexpr int K = 704; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, MidLargeM) +{ + std::vector<int> Ms{127, 255, 312, 799, 1573}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideC); +} + +TYPED_TEST(TestGemmABScale_MK_KN, Regular) +{ + std::vector<int> Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideA = K; + constexpr int StrideB = N; + constexpr int StrideE = N; + + for(int M : Ms) + this->Run(M, N, K, StrideA, StrideB, StrideE); +} + +// Col Row +TYPED_TEST(TestGemmABScale_KM_KN, SmallM) +{ + std::vector<int> Ms{16, 32}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, SmallMPadK) +{ + std::vector<int> Ms{16, 32}; + constexpr int N = 512; + constexpr int K = 704; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, MidLargeM) +{ + std::vector<int> Ms{128, 256}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideC = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideC); + } +} + +TYPED_TEST(TestGemmABScale_KM_KN, Regular) +{ + std::vector<int> Ms{512}; + constexpr int N = 512; + constexpr int K = 1024; + + constexpr int StrideB = N; + constexpr int StrideE = N; + + for(int M : Ms) + { + int StrideA = M; + this->Run(M, N, K, StrideA, StrideB, StrideE); + } +} diff --git a/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp b/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp new file mode 100644 index 0000000000..b54e5ce2e5 --- /dev/null +++ b/test/gemm_ab_scale/test_gemm_ab_scale_util.hpp @@ -0,0 +1,102 @@ +// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT + +#pragma once + +#include +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/device/tensor_layout.hpp" +#include "include/ck/utility/data_type.hpp" +#include "profiler/profile_gemm_ab_scale_impl.hpp" + +namespace ck { +namespace test { + +template <typename Tuple> +class TestGemmABScale : public testing::Test +{ + using F32 = float; + + protected: + using ALayout = std::tuple_element_t<0, Tuple>; + using BLayout = std::tuple_element_t<1, Tuple>; + using ELayout = std::tuple_element_t<2, Tuple>; + using A0DataType = std::tuple_element_t<3, Tuple>; + using A1DataType = std::tuple_element_t<4, Tuple>; + using B0DataType = std::tuple_element_t<5, Tuple>; + using B1DataType = std::tuple_element_t<6, Tuple>; + using ComputeDataType = std::tuple_element_t<7, Tuple>; + using EDataType = std::tuple_element_t<8, Tuple>; + + public: + static constexpr ck::index_t ScaleBlockM = 1; + static constexpr ck::index_t ScaleBlockN = 128; + static constexpr ck::index_t ScaleBlockK = 128; + static constexpr bool verify_ = true; + static constexpr int init_method_ = 1; + static constexpr bool log_ = false; + static constexpr bool bench_ = false; + std::vector<int> k_batches_; + + void SetUp() override { k_batches_ = {1, 2}; } + + void Run(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideE) + { + for(auto kb : k_batches_) + { + RunSingle(M, N, K, StrideA, StrideB, StrideE, kb); + } + } + + void RunSingle(const int M, + const int N, + const int K, + const int StrideA, + const int StrideB, + const int StrideE, + int kbatch = 1, + int n_warmup = 1, + int n_iter = 10) + { + bool pass = ck::profiler::profile_gemm_ab_scale_impl(verify_, + init_method_, + log_, + bench_, + M, + N, + K, + StrideA, + StrideB, + StrideE, + kbatch, + n_warmup, + n_iter); + EXPECT_TRUE(pass); + } +}; + +} // namespace test +} // namespace ck diff --git a/test/gemm_blockscale_wp/CMakeLists.txt b/test/gemm_blockscale_wp/CMakeLists.txt index a095968035..a0750255d1 100644 --- a/test/gemm_blockscale_wp/CMakeLists.txt +++ b/test/gemm_blockscale_wp/CMakeLists.txt @@ -2,8 +2,8 @@ # SPDX-License-Identifier: MIT if(GPU_TARGETS MATCHES "gfx9[45]|gfx12") - add_gtest_executable(test_gemm_blockscale_wp_xdl_fp8 test_gemm_blockscale_wp_xdl_fp8.cpp) + add_gtest_executable(test_gemm_blockscale_wp_fp8 test_gemm_blockscale_wp_fp8.cpp) if(result EQUAL 0) - target_link_libraries(test_gemm_blockscale_wp_xdl_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) + target_link_libraries(test_gemm_blockscale_wp_fp8 PRIVATE utility device_gemm_blockscale_wp_instance) endif() endif() diff --git a/test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp b/test/gemm_blockscale_wp/test_gemm_blockscale_wp_fp8.cpp similarity index 100% rename from test/gemm_blockscale_wp/test_gemm_blockscale_wp_xdl_fp8.cpp rename to test/gemm_blockscale_wp/test_gemm_blockscale_wp_fp8.cpp
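
What the new gemm_ab_scale tests check, in scalar form: with ScaleBlockM = 1, ScaleBlockN = 128 and ScaleBlockK = 128 (the values hard-coded in test_gemm_ab_scale_util.hpp), A carries one f32 scale per row and per 128-wide K tile, while B carries one f32 scale per 128x128 K-by-N tile; for the Regular test shape M=512, N=512, K=1024 the scale tensors a1 and b1 therefore have shapes 512x8 and 8x4. A minimal host-side sketch of that scaled product follows; the function name is made up and float stands in for the one-byte f8 operands, the authoritative reference lives in profile_gemm_ab_scale_impl.hpp:

    // Hedged sketch of one output element of the AB-block-scaled GEMM:
    //   E[m][n] = sum_k a1[m/SBM][k/SBK] * a0[m][k] * b1[k/SBK][n/SBN] * b0[k][n]
    // with row-major a0 (M x K) and row-major b0 (K x N); K and N are assumed
    // divisible by the scale block sizes for brevity.
    float ab_scale_ref(const float* a0, const float* a1, const float* b0, const float* b1,
                       int N, int K, int m, int n)
    {
        constexpr int SBM = 1, SBN = 128, SBK = 128;
        float acc = 0.f;
        for(int k = 0; k < K; ++k)
        {
            const float sa = a1[(m / SBM) * (K / SBK) + k / SBK]; // per-row, per-K-tile A scale
            const float sb = b1[(k / SBK) * (N / SBN) + n / SBN]; // per 128x128-tile B scale
            acc += sa * a0[m * K + k] * sb * b0[k * N + n];
        }
        return acc; // the kernel converts the f32 accumulator to BF16 on store
    }

Because the scales are constant over each 128-wide K tile, a kernel may equivalently apply sa * sb once per per-tile partial sum rather than per element; the two forms are mathematically identical.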
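
On the one-line change to preShuffleBuffer above: KLane = ck::get_warp_size() / NLane replaces a hard-coded wave64 assumption, since gfx9 executes wave64 but gfx12 executes wave32; with NXdl = 16 that means 4 K-lanes per wave on gfx9 versus 2 on gfx12, and K is decomposed as K = K0 * KLane * KPack accordingly. Below is a self-contained sketch of the implied reordering; the destination ordering is an assumption on my part, and preShuffleBuffer in profile_gemm_blockscale_wp_impl.hpp remains the authoritative version:

    #include <cstddef>

    // Hedged sketch: preshuffle row-major B (N x K) into
    // (N / NLane, K0, KLane, NLane, KPack) tiles, with the lane split of K
    // adapting to the wave size (64 / NXdl on gfx9, 32 / NXdl on gfx12).
    // unsigned char stands in for the one-byte f8 element type.
    void preshuffle_sketch(
        const unsigned char* src, unsigned char* dst, int N, int K, int NXdl, int warp_size)
    {
        const int KPack = 16;
        const int NLane = NXdl;              // e.g. 16
        const int KLane = warp_size / NLane; // 4 on wave64, 2 on wave32
        const int K0    = K / (KLane * KPack);
        for(int n = 0; n < N; ++n)
            for(int k = 0; k < K; ++k)
            {
                const int n0 = n / NLane, nl = n % NLane;
                const int k0 = k / (KLane * KPack);
                const int kl = (k / KPack) % KLane;
                const int kp = k % KPack;
                const std::size_t d =
                    ((((std::size_t)n0 * K0 + k0) * KLane + kl) * NLane + nl) * KPack + kp;
                dst[d] = src[(std::size_t)n * K + k];
            }
    }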
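
On reading the instance tables: every DeviceGemmMultiD_ABScale_Wmma_CShuffle_V3 row pairs the 1/128/128 scale blocking with one tile configuration, and the WMMA tiling has to cover the workgroup exactly. Decoding the first compute-friendly mk_nk row (BlockSize 256, MPerBlock 128, NPerBlock 128, MPerWmma = NPerWmma = 16, MRepeat 4, NRepeat 2) under the wave32 assumption that applies to gfx12:

    // Wave-coverage arithmetic for one instance row (my annotation, not library code).
    constexpr int MWaves = 128 / (4 * 16); // MPerBlock / (MRepeat * MPerWmma) = 2
    constexpr int NWaves = 128 / (2 * 16); // NPerBlock / (NRepeat * NPerWmma) = 4
    static_assert(MWaves * NWaves == 256 / 32, "8 wave32 waves tile the 128x128 block exactly");

The memory-friendly lists instead shrink MPerBlock to 16-64, so the skinny shapes exercised by the SmallM tests (M = 1..6) launch proportionally less wasted work per tile.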