From 545819c3625ce8f14ea71bee511ed70a0d24f9cf Mon Sep 17 00:00:00 2001 From: Cong Ma <142121551+CongMa13@users.noreply.github.com> Date: Tue, 22 Jul 2025 08:15:18 -0600 Subject: [PATCH] [CK_TILE] Migrate CK Tile examples to Tests to autorun on CI (#2421) [CK_TILE] Add new ck tile unit test * Add new ck tile unit test smoke-gemm-universal * Add new ck tile unit test smoke-gemm-basic * Add new ck tile unit test topk_softmax * Add new ck tile unit test add_rmsnorm2d_rdquant_fwd [ROCm/composable_kernel commit: f102eedfb3a17079052b5a99885b7acddef0c5a0] --- test/ck_tile/CMakeLists.txt | 4 + .../add_rmsnorm2d_rdquant/CMakeLists.txt | 26 + .../add_rmsnorm2d_rdquant_fwd.hpp | 151 ++++ .../add_rmsnorm2d_rdquant_fwd.inc | 370 +++++++++ .../add_rmsnorm2d_rdquant_fwd_bf16.cpp | 6 + .../add_rmsnorm2d_rdquant_fwd_fp16.cpp | 6 + .../add_rmsnorm2d_rdquant_fwd_api.cpp | 227 ++++++ ...norm2d_rdquant_fwd_bf16_n1024_instance.cpp | 26 + ...norm2d_rdquant_fwd_bf16_n1536_instance.cpp | 17 + ...norm2d_rdquant_fwd_bf16_n2048_instance.cpp | 18 + ...snorm2d_rdquant_fwd_bf16_n256_instance.cpp | 15 + ...norm2d_rdquant_fwd_bf16_n3072_instance.cpp | 17 + ...norm2d_rdquant_fwd_bf16_n4096_instance.cpp | 17 + ...snorm2d_rdquant_fwd_bf16_n512_instance.cpp | 17 + ...m2d_rdquant_fwd_bf16_n64_n128_instance.cpp | 15 + ...snorm2d_rdquant_fwd_bf16_n768_instance.cpp | 15 + ...norm2d_rdquant_fwd_bf16_n8192_instance.cpp | 42 + ...m2d_rdquant_fwd_bf16_n8192_tp_instance.cpp | 17 + ...norm2d_rdquant_fwd_fp16_n1024_instance.cpp | 26 + ...norm2d_rdquant_fwd_fp16_n1536_instance.cpp | 17 + ...norm2d_rdquant_fwd_fp16_n2048_instance.cpp | 18 + ...snorm2d_rdquant_fwd_fp16_n256_instance.cpp | 15 + ...norm2d_rdquant_fwd_fp16_n3072_instance.cpp | 17 + ...norm2d_rdquant_fwd_fp16_n4096_instance.cpp | 17 + ...snorm2d_rdquant_fwd_fp16_n512_instance.cpp | 17 + ...m2d_rdquant_fwd_fp16_n64_n128_instance.cpp | 15 + ...snorm2d_rdquant_fwd_fp16_n768_instance.cpp | 15 + ...norm2d_rdquant_fwd_fp16_n8192_instance.cpp | 41 + 
...m2d_rdquant_fwd_fp16_n8192_tp_instance.cpp | 17 + ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 70 ++ test/ck_tile/gemm/CMakeLists.txt | 19 + .../gemm/test_gemm_pipeline_basic_bf16.cpp | 5 + .../gemm/test_gemm_pipeline_basic_bf8.cpp | 5 + .../gemm/test_gemm_pipeline_basic_fp16.cpp | 5 + .../gemm/test_gemm_pipeline_basic_fp8.cpp | 5 + .../test_gemm_pipeline_basic_run_test.inc | 313 ++++++++ .../test_gemm_pipeline_smoke_run_test.inc | 458 +++++++++++ .../gemm/test_gemm_pipeline_smoke_util.hpp | 414 ++++++++++ .../test_gemm_pipeline_universal_bf16.cpp | 16 + .../gemm/test_gemm_pipeline_universal_bf8.cpp | 16 + .../test_gemm_pipeline_universal_fp16.cpp | 16 + .../gemm/test_gemm_pipeline_universal_fp8.cpp | 16 + .../test_gemm_pipeline_universal_run_test.inc | 393 ++++++++++ test/ck_tile/layernorm2d/CMakeLists.txt | 53 ++ test/ck_tile/layernorm2d/generate.py | 730 ++++++++++++++++++ test/ck_tile/layernorm2d/layernorm2d_fwd.hpp | 70 ++ test/ck_tile/layernorm2d/layernorm2d_fwd.inc | 566 ++++++++++++++ .../layernorm2d/layernorm2d_fwd_bf16.cpp | 6 + .../layernorm2d/layernorm2d_fwd_fp16.cpp | 6 + test/ck_tile/rmsnorm2d/CMakeLists.txt | 54 ++ test/ck_tile/rmsnorm2d/generate.py | 715 +++++++++++++++++ test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp | 69 ++ test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc | 619 +++++++++++++++ test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_bf16.cpp | 5 + test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_fp16.cpp | 5 + test/ck_tile/topk_softmax/CMakeLists.txt | 19 + .../topk_softmax/test_topk_softmax.hpp | 280 +++++++ .../topk_softmax/test_topk_softmax_api.cpp | 96 +++ .../topk_softmax/test_topk_softmax_api.hpp | 21 + .../topk_softmax/test_topk_softmax_bf16.cpp | 6 + .../topk_softmax/test_topk_softmax_fp16.cpp | 6 + 61 files changed, 6298 insertions(+) create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp create mode 100644 
test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp create mode 100644 
test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp create mode 100644 test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp create mode 100644 test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc create mode 100644 test/ck_tile/layernorm2d/CMakeLists.txt create mode 100644 
test/ck_tile/layernorm2d/generate.py create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd.hpp create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd.inc create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd_bf16.cpp create mode 100644 test/ck_tile/layernorm2d/layernorm2d_fwd_fp16.cpp create mode 100644 test/ck_tile/rmsnorm2d/CMakeLists.txt create mode 100644 test/ck_tile/rmsnorm2d/generate.py create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_bf16.cpp create mode 100644 test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_fp16.cpp create mode 100644 test/ck_tile/topk_softmax/CMakeLists.txt create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax.hpp create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax_api.cpp create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax_api.hpp create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax_bf16.cpp create mode 100644 test/ck_tile/topk_softmax/test_topk_softmax_fp16.cpp diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 648fdc7739..3e5a3034cd 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -13,3 +13,7 @@ add_subdirectory(moe_sorting) add_subdirectory(slice_tile) add_subdirectory(batched_transpose) add_subdirectory(smoothquant) +add_subdirectory(topk_softmax) +add_subdirectory(add_rmsnorm2d_rdquant) +# add_subdirectory(layernorm2d) +# add_subdirectory(rmsnorm2d) diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt new file mode 100644 index 0000000000..37774f7643 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt @@ -0,0 +1,26 @@ +function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX) + set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "test_ck_tile_add_rmsnorm2d_rdquant_fwd_${SUFFIX}") + message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") + 
file(GLOB INSTANCE_SRCS instances/*.cpp) + add_test_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd_${SUFFIX}.cpp) + target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS}) + + set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + + # TODO: we have to turn off this global prop, otherwise the progress bar generated + # by cmake will print too many files, execvp: /bin/sh: Argument list too long + # however, this property may affect global + # TODO: consider codegen a makefile by us + set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) +endfunction() + +if(GPU_TARGETS MATCHES "gfx9") + create_tile_add_rmsnorm2d_rdquant_fwd("fp16") + create_tile_add_rmsnorm2d_rdquant_fwd("bf16") +else() + message(DEBUG "Skipping ck tile add_rmsnorm2d_rdquant_fwd tests for current target") +endif() diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp new file mode 100644 index 0000000000..faa134e5c4 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -0,0 +1,151 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" +#include + +template +struct AddRmsnormRdquantTypeConfig; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using XDataType = ck_tile::half_t; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using XDataType = ck_tile::bf16_t; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using XDataType = ck_tile::half_t; + using YScaleDataType = float; + using QYDataType = ck_tile::fp8_t; + using ComputeDataType = float; +}; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using XDataType = ck_tile::bf16_t; + using YScaleDataType = float; + using QYDataType = ck_tile::fp8_t; + using ComputeDataType = float; +}; + +// runtime args +struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs +{ +}; + +// this is used to pattern-match internal kernel implementation, not to instantiate kernel +template +struct add_rmsnorm2d_rdquant_fwd_traits_ +{ + using InputDataType = ck_tile::remove_cvref_t; + using QuantizedDataType = ck_tile::remove_cvref_t; + + static constexpr auto WarpSize = ck_tile::get_warp_size(); + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize; + static_assert((ThreadPerBlock_M_ * 
ThreadPerBlock_N_) % WarpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return total_warps * (WarpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(WarpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / WarpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(WarpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % WarpSize == 0); + return ThreadPerBlock_N_ / WarpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveX = kSaveX_; + static constexpr bool kThreePass = kThreePass_; +}; + +template +float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_rdquant_fwd_args a); + +// This is the public API, will be generated by script +struct add_rmsnorm2d_rdquant_fwd_traits +{ + std::string input_data_type; + std::string quantized_data_type; + bool save_x; +}; + +float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits, + 
add_rmsnorm2d_rdquant_fwd_args, + const ck_tile::stream_config&); diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc new file mode 100644 index 0000000000..b7cf891862 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc @@ -0,0 +1,370 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/host.hpp" +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_x", "1", "save rms(invrms) or not. 
set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("quant", "int8", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string input_data_type = arg_parser.get_str("prec"); + std::string quantized_data_type = arg_parser.get_str("quant"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = AddRmsnormRdquantTypeConfig; + + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XDataType = typename TypeConfig::XDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = float; + using UnquantYDataType = ck_tile::null_type; + + // host verify + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor x_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, 
.5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + std::cout << "[" << input_data_type << ", " << quantized_data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX}; + + add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + x_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + epsilon, + m, + n, + stride}; + + float ave_time = add_rmsnorm2d_rdquant_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 
1 : 0, warmup, repeat}); + + std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(BDataType) * m * n + + sizeof(GammaDataType) * n + sizeof(YScaleDataType) * m + + sizeof(QYDataType) * m * n; + + if constexpr(SaveX) + num_byte += sizeof(XDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::endl; + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + using InvRmsDataType = InputDataType; + + // Add + { + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + ck_tile::reference_binary_elementwise( + a_host, b_host, x_host_ref, op); + + if constexpr(SaveX) + { + x_buf.FromDevice(x_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err(x_host_dev, + x_host_ref, + std::string("x Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + } + + ck_tile::HostTensor y_host({m, n}); + // Rmsnorm2d + { + ck_tile::HostTensor invRms_host_ref({m}); + ck_tile::HostTensor unquant_y_host_ref({m, n}); + + // CAUTION: kernel uses ComputeDataType version of x, but we use XDataType here for + // simplicity + ck_tile::reference_rmsnorm2d_fwd( + x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return 
v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << ", valid:" << (pass ? 
"y" : "n") << std::flush << std::endl; + } + + return pass; +} + +bool dispatch_by_type(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + const std::string input_data_type = arg_parser.get_str("prec"); + const std::string quantized_data_type = arg_parser.get_str("quant"); + int save_x = arg_parser.get_int("save_x"); + if(input_data_type == "fp16" && quantized_data_type == "int8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "fp16" && quantized_data_type == "int8" && !save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "int8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "int8" && !save_x) + { + return run(arg_parser); + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "fp16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && save_x) + { + return run(arg_parser); + } + else if(input_data_type == "bf16" && quantized_data_type == "fp8" && !save_x) + { + return run(arg_parser); + } + + return false; +} + +int run_add_rmsnorm2d_rdquant_combinations(std::string const& data_type) +{ + constexpr size_t PARAM_COUNT = 11; + char bufs[PARAM_COUNT][64]; + char* argv[PARAM_COUNT]; + + for(std::size_t i = 0; i < PARAM_COUNT; i++) + { + argv[i] = bufs[i]; + } + + std::vector> params = { + {"-m=99", "-n=13"}, + {"-m=17", "-n=16"}, + {"-m=1", "-n=100"}, + {"-m=4", "-n=128"}, + {"-m=80", "-n=127"}, + {"-m=22", "-n=255", "-stride=256"}, + {"-m=7", "-n=599"}, + {"-m=19", "-n=512"}, + {"-m=33", "-n=313", "-stride=1000"}, + {"-m=11", "-n=510"}, + {"-m=171", "-n=676", "-stride=818"}, + {"-m=91", "-n=636"}, + {"-m=12", "-n=768", "-stride=800"}, + {"-m=100", "-n=766", "-stride=812"}, + {"-m=31", 
"-n=1024"}, + {"-m=64", "-n=1000", "-stride=1004"}, + {"-m=8", "-n=1501"}, + {"-m=3", "-n=1826"}, + {"-m=5", "-n=2040"}, + {"-m=7", "-n=2734"}, + {"-m=1", "-n=3182"}, + {"-m=9", "-n=4096"}, + {"-m=3", "-n=8192"}, + {"-m=1", "-n=10547"}, + {"-m=3", "-n=17134"}, + }; + + bool result = true; + std::string pr_i = "-prec=" + data_type; + strncpy(bufs[0], "add_rmsnorm2d_rdquant_fwd", 64); + strncpy(bufs[1], pr_i.c_str(), 64); + for(size_t i = 0; i < params.size(); i++) + { + for(size_t j = 0; j < params[i].size(); j++) + { + strncpy(bufs[j + 2], params[i][j].c_str(), 64); + } + int argc = params[i].size() + 2; + + result = dispatch_by_type(argc, argv) && result; + } + return result ? 0 : -1; +} diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp new file mode 100644 index 0000000000..1e0863fa62 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_bf16.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd.inc" + +int main() { return run_add_rmsnorm2d_rdquant_combinations("bf16"); } diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp new file mode 100644 index 0000000000..0a0a4c4f83 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd_fp16.cpp @@ -0,0 +1,6 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd.inc" + +int main() { return run_add_rmsnorm2d_rdquant_combinations("fp16"); } diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp new file mode 100644 index 0000000000..f695ea30b2 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp @@ -0,0 +1,227 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t, + add_rmsnorm2d_rdquant_fwd_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd x 3p + if(a.n <= 64) { + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = 
add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 8192) { + if(a.n<8192){ + if(t.save_x){ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else{ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + } + else{ + if(t.save_x){ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + 
else{ + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + } + } + else if(a.n > 8192) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + return r; + // clang-format on +} + +float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, + add_rmsnorm2d_rdquant_fwd_args a, + const ck_tile::stream_config& s) +{ + if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + !t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + t.save_x) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 && + !t.save_x) + { + 
return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else + throw std::runtime_error("Without supported instances!"); +} diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp new file mode 100644 index 0000000000..00df2f5082 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp @@ -0,0 +1,26 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..2adb54c078 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro 
Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..39089843a2 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000..ddb8e1b354 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..2a87614403 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..045a3b8880 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000..1028973e74 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..b8439a0ce9 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000..b24b245757 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp new file mode 100644 index 0000000000..14f0ec8525 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_instance.cpp @@ -0,0 +1,42 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git 
a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..3e3a6d75b9 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..04d735c12c --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,26 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..5893d6c3ee --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..ec9c417bf3 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,18 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000..5bc8245106 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..c022c62de6 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..19172b0793 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// 
Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000..f491d92787 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..065f0ea4cc --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000..be8c6c4de5 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp @@ -0,0 +1,15 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp new file mode 100644 index 0000000000..ad2dfd931e --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_instance.cpp @@ -0,0 +1,41 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git 
a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp new file mode 100644 index 0000000000..e3afa07fa4 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n8192_tp_instance.cpp @@ -0,0 +1,17 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp new file mode 100644 index 0000000000..25b10e1dc4 --- /dev/null +++ b/test/ck_tile/add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -0,0 +1,70 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = add_rmsnorm2d_rdquant_fwd_args; + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) +{ + using InputDataType = typename Traits_::InputDataType; + using QuantizedDataType = typename Traits_::QuantizedDataType; + + using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem< + typename AddRmsnormRdquantTypeConfig::ADataType, + typename AddRmsnormRdquantTypeConfig::BDataType, + typename AddRmsnormRdquantTypeConfig::GammaDataType, + typename AddRmsnormRdquantTypeConfig::ComputeDataType, + typename AddRmsnormRdquantTypeConfig::XDataType, + typename AddRmsnormRdquantTypeConfig::YScaleDataType, + typename AddRmsnormRdquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kSaveX, + Traits_::kThreePass>; + + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/test/ck_tile/gemm/CMakeLists.txt b/test/ck_tile/gemm/CMakeLists.txt index 8f880b8fde..6cbdc1a24e 100644 --- a/test/ck_tile/gemm/CMakeLists.txt +++ b/test/ck_tile/gemm/CMakeLists.txt @@ -20,6 +20,16 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95") target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE 
${EXAMPLE_GEMM_COMPILE_OPTIONS}) target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS}) + + + add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) else() message(DEBUG "Skipping ck_tile_gemm tests for current target") endif() @@ -27,4 +37,13 @@ endif() if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a") add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp) target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + + add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) + 
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp) + target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS}) endif() diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp new file mode 100644 index 0000000000..af2cb398f5 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "test_gemm_pipeline_basic_run_test.inc" + +int main() { return run_gemm_combinations("bf16"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp new file mode 100644 index 0000000000..fd8c28ef17 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "test_gemm_pipeline_basic_run_test.inc" + +int main() { return run_gemm_combinations("bf8"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp new file mode 100644 index 0000000000..4a93d6046a --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#include "test_gemm_pipeline_basic_run_test.inc" + +int main() { return run_gemm_combinations("fp16"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp new file mode 100644 index 0000000000..fd8c28ef17 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp @@ -0,0 +1,5 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+#include "test_gemm_pipeline_basic_run_test.inc" + +int main() { return run_gemm_combinations("fp8"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc new file mode 100644 index 0000000000..9e4c036655 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc @@ -0,0 +1,313 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include + +#include +#include +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_smoke_util.hpp" +#include "test_gemm_pipeline_smoke_run_test.inc" + +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + +{ + if constexpr(Persistent) + std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl; + // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part. + constexpr bool kPadM = false; + constexpr bool kPadN = false; + constexpr bool kPadK = false; + + constexpr int kBlockPerCu = 1; + + // This part comes from the Codegen + constexpr ck_tile::index_t M_Tile = 256; + constexpr ck_tile::index_t N_Tile = 256; + constexpr ck_tile::index_t K_Tile = 64; + + constexpr ck_tile::index_t M_Warp = 2; + constexpr ck_tile::index_t N_Warp = 2; + constexpr ck_tile::index_t K_Warp = 1; + + constexpr ck_tile::index_t M_Warp_Tile = 32; + constexpr ck_tile::index_t N_Warp_Tile = 32; + constexpr ck_tile::index_t K_Warp_Tile = 16; + + using CodegenGemmShape = + ck_tile::TileGemmShape, + ck_tile::sequence, + ck_tile::sequence>; + + using TilePartitioner = ck_tile::GemmTile1DPartitioner; + + using CodegenGemmTraits = + ck_tile::TileGemmTraits; + + using CodegenPipelineProblem = ck_tile:: + GemmPipelineProblem; + + using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1; + + const auto Run = [&](const auto memory_operation_) { + constexpr auto memory_operation = 
memory_operation_.value; + + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem, + AccDataType, + CDataType, + ck_tile::tuple<>, + CLayout, + ck_tile::element_wise::PassThrough, + CodegenPipelineProblem::kBlockSize, + TilePartitioner::MPerBlock, + TilePartitioner::NPerBlock, + M_Warp, + N_Warp, + M_Warp_Tile, + N_Warp_Tile, + K_Warp_Tile, + CodegenPipelineProblem::TransposeC, + memory_operation>>; + + // ToDo: Will add the codegen part to test different pipeline policies in GEMM. + // Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy. + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch); + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw ArgumentsNotSupportedException( + "Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << CodegenGemmShape::GetName() << '\n' + << "problem: " << CodegenPipelineProblem::GetName() << '\n' + << "pipeline: " << CodegenGemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}" + << std::endl; + } + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + return ave_time; + }; + + if(args.k_batch == 1) + { + return Run(ck_tile::integral_constant{}); + } + else + { + return Run(ck_tile::integral_constant{}); + } +} + +template +bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_test_with_layouts( + 
argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + } + else + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "R" && b_layout == "R") + { + return run_gemm_test_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_test_with_layouts( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } +} + +bool run_gemm_test(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + std::string data_type = arg_parser.get_str("prec"); + std::string a_layout = arg_parser.get_str("a_layout"); + std::string b_layout = arg_parser.get_str("b_layout"); + + if(data_type == "fp16") + { + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "bf16") + { + return run_gemm_test_prec_type(a_layout, b_layout, argc, argv); + } + else if(data_type == "fp8") + { + return run_gemm_test_prec_type( + a_layout, b_layout, argc, argv); + } + else if(data_type == "bf8") + { + return run_gemm_test_prec_type( + a_layout, b_layout, argc, argv); + } + else if(data_type == "pk_int4_t") + { + // TODO: Add support for bhalf_t ADataType + if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3) + { + return run_gemm_test_prec_type( + a_layout, b_layout, argc, argv); + } + else + { + throw std::runtime_error("Unsupported data type for this operation !!!"); + } + } + else + { + throw 
std::runtime_error("Unsupported data type for this operation !!!"); + } +} + +int run_gemm_combinations(std::string const& data_type) +{ + // Define possible values for each parameter + std::vector m_values = {"128", "1024"}; + std::vector n_values = {"128", "2048"}; + std::vector k_values = {"64", "128"}; + std::vector prec_values = {data_type}; + + // We'll store all our arguments as strings first + std::vector arg_strings = {"./bin/tile_example_gemm_basic", + "", // m placeholder + "", // n placeholder + "", // k placeholder + "-stride_a=0", + "-stride_b=0", + "-stride_c=0", + "", // prec placeholder + "-v=2", + "-warmup=0", + "-repeat=1"}; + + // Create an array of const char pointers for argv + constexpr size_t ARG_COUNT = 11; + constexpr size_t ARG_MAX_LEN = 64; + char args[ARG_COUNT][ARG_MAX_LEN]; + char* argv[ARG_COUNT]; + + // Run all combinations + bool is_success = true; + for(const auto& m : m_values) + { + arg_strings[1] = "-m=" + m; + + for(const auto& n : n_values) + { + arg_strings[2] = "-n=" + n; + + for(const auto& k : k_values) + { + arg_strings[3] = "-k=" + k; + + for(const auto& prec : prec_values) + { + arg_strings[7] = "-prec=" + prec; + + // Set up the argv array with pointers to the string data + for(size_t i = 0; i < ARG_COUNT; i++) + { + strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN); + argv[i] = args[i]; + } + + std::cout << "Arguments received: "; + for(size_t i = 1; i < ARG_COUNT; ++i) + { + std::cout << argv[i] << " "; + } + std::cout << std::endl; + + // Call the function with the current configuration + try + { + is_success = run_gemm_test(ARG_COUNT, argv) && is_success; + } + catch(const ArgumentsNotSupportedException& e) + { + std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n'; + // ArgumentsNotSupportedException is not an error. 
Do not change is_success + } + catch(const std::runtime_error& e) + { + std::cerr << "Caught runtime error: " << e.what() << '\n'; + is_success = false; + } + } + } + } + } + return is_success ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc new file mode 100644 index 0000000000..afa6912e0f --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc @@ -0,0 +1,458 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +static constexpr inline auto is_row_major(Layout layout_) +{ + return ck_tile::bool_constant, + ck_tile::tensor_layout::gemm::RowMajor>>{}; +} + +template +auto calculate_rtol_atol(const ck_tile::index_t K, + const ck_tile::index_t kbatch, + const float max_accumulated_value) +{ + using ComputeType = + std::conditional_t; + // Calculate thresholds + const auto rtol = ck_tile::get_relative_threshold( + ck_tile::integer_divide_ceil(K, kbatch)); + const auto atol = ck_tile::get_absolute_threshold( + max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch)); + // Calculate error due to split_k accumulation + const auto rtol_split_k = + ck_tile::get_relative_threshold(kbatch); + const auto atol_split_k = ck_tile::get_absolute_threshold( + max_accumulated_value, kbatch); + // Use higher threshold + return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k)); +} + +template +void permute_tensor_b(Tensor& tensor) +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits::template GemmPipeline< + UniversalGemmProblem>; 
+ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB(); + const ck_tile::index_t K0 = K / K1; + + Tensor tensor_copy = tensor; + + // int K0, N, K1 + for(int j = 0; j < K0; j++) + { + for(int i = 0; i < N; i++) + { + for(int jj = 0; jj < K1; jj++) + { + tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj)); + } + } + } +} + +template +void permute_vectors_i4x4_b(Tensor& tensor) +{ + const ck_tile::index_t K = tensor.get_length(0); + const ck_tile::index_t N = tensor.get_length(1); + // vector pk_i4x4 permute + for(int i = 0; i < N; i++) + { + for(int j = 0; j < K; j += 8) + { + int8_t input[8]; + + for(int k = 0; k < 4; k++) + { + int8_t i4x2 = tensor(j + k * 2, i).data; + input[k * 2 + 0] = (i4x2 >> 4) & 0xf; + input[k * 2 + 1] = (i4x2 >> 0) & 0xf; + } + + // permute 01234567->20643175 + { + int8_t hi = input[2]; + int8_t lo = input[0]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 0, i) = i4x2; + } + + { + int8_t hi = input[6]; + int8_t lo = input[4]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 2, i) = i4x2; + } + + { + int8_t hi = input[3]; + int8_t lo = input[1]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 4, i) = i4x2; + } + + { + int8_t hi = input[7]; + int8_t lo = input[5]; + int8_t i4x2 = (hi << 4) | lo; + + tensor(j + 6, i) = i4x2; + } + } + } +} + +template +float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s); + +template +float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf, + ck_tile::DeviceMem& b_k_n_dev_buf, + ck_tile::DeviceMem& c_m_n_dev_buf, + ck_tile::index_t M, + ck_tile::index_t N, + ck_tile::index_t K, + ck_tile::index_t stride_A, + ck_tile::index_t stride_B, + ck_tile::index_t stride_C, + ck_tile::index_t kbatch, + int n_warmup, + int n_repeat, + bool persistent) +{ + ck_tile::GemmHostArgs args = {a_m_k_dev_buf.GetDeviceBuffer(), + b_k_n_dev_buf.GetDeviceBuffer(), + {}, + 
c_m_n_dev_buf.GetDeviceBuffer(), + kbatch, + M, + N, + K, + stride_A, + stride_B, + {}, + stride_C}; + + float ave_time; + if(persistent) + { + ave_time = gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + } + else + { + ave_time = gemm( + args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50}); + } + + std::size_t flop = std::size_t(2) * M * N * K; + std::size_t num_byte = + sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N; + float tflops = static_cast(flop) / 1.E9 / ave_time; + float gb_per_sec = num_byte / 1.E6 / ave_time; + + std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K + << " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C + << " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name + << " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits::name + << " B_Type=" << DataTypeTraits::name + << " C_Type=" << DataTypeTraits::name + << " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off") + << " Persistent=" << (persistent ? 
"on" : "off") << " : " << ave_time << " ms, " + << tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl; + + return ave_time; +} + +template +bool run_gemm_test_with_layouts(int argc, + char* argv[], + const ALayout a_layout = ALayout{}, + const BLayout b_layout = BLayout{}, + [[maybe_unused]] const CLayout c_layout = CLayout{}) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + using AccDataType = typename GemmTypeConfig::AccDataType; + + ck_tile::index_t M = arg_parser.get_int("m"); + ck_tile::index_t N = arg_parser.get_int("n"); + ck_tile::index_t K = arg_parser.get_int("k"); + + ck_tile::index_t stride_A = arg_parser.get_int("stride_a"); + ck_tile::index_t stride_B = arg_parser.get_int("stride_b"); + ck_tile::index_t stride_C = arg_parser.get_int("stride_c"); + + ck_tile::index_t kbatch = arg_parser.get_int("split_k"); + int n_warmup = arg_parser.get_int("warmup"); + int n_repeat = arg_parser.get_int("repeat"); + ck_tile::index_t init_method = arg_parser.get_int("init"); + bool persistent = arg_parser.get_int("persistent"); + + stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout)); + stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout)); + stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{})); + + ck_tile::HostTensor a_m_k( + ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout))); + ck_tile::HostTensor b_k_n( + ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout))); + ck_tile::HostTensor c_m_n_dev_result( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + + if(init_method == 0) + { + ck_tile::FillUniformDistribution{-5.f, 5.f}(a_m_k); + ck_tile::FillUniformDistribution{-5.f, 5.f}(b_k_n); + } + else if(init_method == 1) + { + ck_tile::FillMonotonicSeq{}(a_m_k); + ck_tile::FillMonotonicSeq{}(b_k_n); + } + else if(init_method == 2) + { + ck_tile::FillUniformDistribution{1.f, 
1.f}(a_m_k); + ck_tile::FillUniformDistribution{1.f, 1.f}(b_k_n); + } + else + { + a_m_k.SetZero(); + b_k_n.SetZero(); + } + + if(GemmConfig::UseStructuredSparsity) + { + ck_tile::AdjustToStructuredSparsity{}(a_m_k); + } + + ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes()); + ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes()); + + static_assert(!GemmConfig::PermuteA, "Not implemented"); + if constexpr(std::is_same_v) + { + // Permute vector pk_i4x4 data for device implementation + ck_tile::HostTensor b_k_n_dev = b_k_n; + if constexpr(GemmConfig::PermuteB) + { + permute_tensor_b(b_k_n_dev); + } + permute_vectors_i4x4_b(b_k_n_dev); + b_k_n_dev_buf.ToDevice(b_k_n_dev.data()); + } + else + { + if constexpr(GemmConfig::PermuteB) + { + std::cout << "Permute for this DataType is not implemented." << std::endl; + return false; + } + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + a_m_k_dev_buf.ToDevice(a_m_k.data()); + c_m_n_dev_buf.SetZero(); + c_m_n_dev_result.SetZero(); + + invoke_gemm, + AccDataType, + CDataType, + ALayout, + BLayout, + ck_tile::tuple<>, + CLayout>(a_m_k_dev_buf, + b_k_n_dev_buf, + c_m_n_dev_buf, + M, + N, + K, + stride_A, + stride_B, + stride_C, + kbatch, + n_warmup, + n_repeat, + persistent); + + c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data()); + bool pass = true; + + if(arg_parser.get_int("v") == 1) + { + ck_tile::HostTensor c_m_n_host_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + c_m_n_host_ref.SetZero(); + + ck_tile::reference_gemm( + a_m_k, b_k_n, c_m_n_host_ref); + const float max_accumulated_value = + *std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_host_ref, + "Error: Incorrect results!", + 
rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl; + } + else if(arg_parser.get_int("v") == 2) + { + if constexpr(std::is_same_v) + { + // Restore input for B for gpu reference + b_k_n_dev_buf.ToDevice(b_k_n.data()); + } + + // memory on host to store gpu reference result + ck_tile::HostTensor c_m_n_gpu_ref( + ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{}))); + // memory on device to store gpu reference result + ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes()); + + c_m_n_gpu_ref.SetZero(); + c_m_n_gpu_buf_ref.SetZero(); + + ADataType* d_A = static_cast(a_m_k_dev_buf.GetDeviceBuffer()); + BDataType* d_B = static_cast(b_k_n_dev_buf.GetDeviceBuffer()); + CDataType* d_C = static_cast(c_m_n_gpu_buf_ref.GetDeviceBuffer()); + + ck_tile::reference_gemm_gpu(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C); + + c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data()); + + const float max_accumulated_value = + *std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end()); + const auto rtol_atol = calculate_rtol_atol( + K, kbatch, max_accumulated_value); + pass = ck_tile::check_err(c_m_n_dev_result, + c_m_n_gpu_ref, + "Error: Incorrect results!", + rtol_atol.at(ck_tile::number<0>{}), + rtol_atol.at(ck_tile::number<1>{})); + std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{}) + << " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{}) + << std::endl; + std::cout << "The GPU verification result is: " << (pass ? 
"correct" : "fail") << std::endl; + } + + return pass; +} diff --git a/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp new file mode 100644 index 0000000000..99a1e50a6f --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp @@ -0,0 +1,414 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/epilogue.hpp" +#include "ck_tile/ops/gemm.hpp" + +#define CK_TILE_PIPELINE_COMPUTE_V3 1 +#define CK_TILE_PIPELINE_MEMORY 2 +#define CK_TILE_PIPELINE_COMPUTE_V4 3 +#define CK_TILE_PIPELINE_COMPUTE_V5 4 + +class ArgumentsNotSupportedException : public std::logic_error +{ + public: + explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {} +}; + +// temporary workaround to get k_warp_tile based on PrecType and gfx950 or not +template +constexpr ck_tile::index_t get_k_warp_tile() +{ +#if defined(__gfx950__) + constexpr bool is_8bit_float = + std::is_same_v || std::is_same_v; + if constexpr(M_Warp_Tile == 32) + return is_8bit_float ? 64 : 16; + else + return is_8bit_float ? 
128 : 32; +#else + if constexpr(M_Warp_Tile == 32) + return 16; + else + return 32; +#endif +} + +struct GemmConfigBase +{ + static constexpr bool kPadM = false; + static constexpr bool kPadN = false; + static constexpr bool kPadK = false; + + static constexpr bool PermuteA = false; + static constexpr bool PermuteB = false; + + static constexpr bool TransposeC = false; + static constexpr bool UseStructuredSparsity = false; + + static constexpr int kBlockPerCu = 1; + static constexpr ck_tile::index_t TileParitionerGroupNum = 8; + static constexpr ck_tile::index_t TileParitionerM01 = 4; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + static constexpr ck_tile::index_t NumWaveGroups = 1; +}; + +template +struct GemmConfigMemoryInterwave : public GemmConfigBase +{ + // Memory friendly for Interwave scheduler + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 
8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; + static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave; +}; + +template +struct GemmConfigMemoryIntrawave : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 32; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 4; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16; + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY; +}; + +template +struct GemmConfigComputeV3 : public GemmConfigBase +{ + // Compute V3 only support Intrawave scheduler + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; 
+ static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; +}; + +template +struct GemmConfigComputeV3_2 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 16; + static constexpr ck_tile::index_t N_Warp_Tile = 16; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3; + + static constexpr int kBlockPerCu = 2; +}; + +template +struct GemmConfigComputeV4 : public GemmConfigBase +{ + // Compute V4 only support Intrawave scheduler + // Using the ping pong reader in the lds level + static constexpr ck_tile::index_t M_Tile = 256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV4_1 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 
256; + static constexpr ck_tile::index_t N_Tile = 256; + static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 2; + static constexpr ck_tile::index_t N_Warp = 2; + static constexpr ck_tile::index_t K_Warp = 1; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = true; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4; +}; + +template +struct GemmConfigComputeV5 : public GemmConfigBase +{ + static constexpr ck_tile::index_t M_Tile = 128; + static constexpr ck_tile::index_t N_Tile = 128; + static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType); + + static constexpr ck_tile::index_t M_Warp = 1; + static constexpr ck_tile::index_t N_Warp = 1; + static constexpr ck_tile::index_t K_Warp = 2; + + static constexpr ck_tile::index_t M_Warp_Tile = 32; + static constexpr ck_tile::index_t N_Warp_Tile = 32; + static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile(); + + static constexpr bool DoubleSmemBuffer = false; + static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5; + static constexpr ck_tile::index_t NumWaveGroups = 2; +}; + +template +struct GemmTypeConfig; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; + // ToDo: Add more bias config to support different categories of GEMM. 
+}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using AccDataType = float; + using CDataType = ck_tile::bf16_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::fp8_t; + using BDataType = ck_tile::fp8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::bf8_t; + using BDataType = ck_tile::bf8_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::pk_int4_t; + using AccDataType = float; + using CDataType = ck_tile::half_t; +}; + +template <> +struct GemmTypeConfig +{ + using ADataType = ck_tile::int8_t; + using BDataType = ck_tile::int8_t; + using AccDataType = int32_t; + using CDataType = int32_t; +}; + +template +struct DataTypeTraits; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp64"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int32"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf16"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "fp8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "bf8"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "pk_int4_t"; +}; + +template <> +struct DataTypeTraits +{ + static constexpr const char* name = "int8"; +}; + +template +struct PipelineTypeTraits; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem; + template + using UniversalGemmPipeline = 
ck_tile::BaseGemmPipelineAgBgCrMem; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4; +}; + +template <> +struct PipelineTypeTraits +{ + template + using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5; + template + using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5; +}; + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3840", "m dimension") + .insert("n", "4096", "n dimension") + .insert("k", "2048", "k dimension") + .insert("a_layout", "R", "A tensor data layout - Row by default") + .insert("b_layout", "C", "B tensor data layout - Column by default") + .insert("c_layout", "R", "C tensor data layout - Row by default") + .insert("stride_a", "0", "Tensor A stride") + .insert("stride_b", "0", "Tensor B stride") + .insert("stride_c", "0", "Tensor C stride") + .insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU") + .insert("prec", "fp16", "data type. 
fp16/bf16/fp8/bf8") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer") + .insert("split_k", "1", "splitK value") + .insert("init", "0", "0:random, 1:linear, 2:constant(1)") + .insert("persistent", "0", "0:non-persistent, 1:persistent"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +// host API +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s); diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp new file mode 100644 index 0000000000..0673272f5f --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_smoke_util.hpp" +#include "test_gemm_pipeline_smoke_run_test.inc" +#include "test_gemm_pipeline_universal_run_test.inc" + +int main() { return run_gemm_combinations("bf16"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp new file mode 100644 index 0000000000..70eae12e82 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_smoke_util.hpp" +#include "test_gemm_pipeline_smoke_run_test.inc" +#include "test_gemm_pipeline_universal_run_test.inc" + +int main() { return run_gemm_combinations("bf8"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp new file mode 100644 index 0000000000..8ea192c7f3 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_smoke_util.hpp" +#include "test_gemm_pipeline_smoke_run_test.inc" +#include "test_gemm_pipeline_universal_run_test.inc" + +int main() { return run_gemm_combinations("fp16"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp new file mode 100644 index 0000000000..20414b4fec --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp @@ -0,0 +1,16 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include + +#include +#include +#include + +#include "ck_tile/host.hpp" +#include "test_gemm_pipeline_smoke_util.hpp" +#include "test_gemm_pipeline_smoke_run_test.inc" +#include "test_gemm_pipeline_universal_run_test.inc" + +int main() { return run_gemm_combinations("fp8"); } diff --git a/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc new file mode 100644 index 0000000000..1980648391 --- /dev/null +++ b/test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc @@ -0,0 +1,393 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +#pragma once + +template +float gemm(const ck_tile::GemmHostArgs& args, const ck_tile::stream_config& s) + +{ + using GemmShape = ck_tile::TileGemmShape< + ck_tile::sequence, + ck_tile::sequence, + ck_tile:: + sequence, + GemmConfig::PermuteA, + GemmConfig::PermuteB>; + + using TilePartitioner = + ck_tile::GemmSpatiallyLocalTilePartitioner; + + using Traits = ck_tile::TileGemmTraits; + + using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits; + using GemmPipelineProblem = + ck_tile::GemmPipelineProblem; + + using BaseGemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template UniversalGemmPipeline; + + const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile; + const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile; + const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split); + const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop); + const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop); + + float ave_time{0}; + + const auto Run = + [&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) { + constexpr bool has_hot_loop_v = has_hot_loop_.value; + constexpr auto tail_number_v = tail_number_.value; + constexpr auto scheduler = 
GemmConfig::Scheduler; + constexpr auto memory_operation = memory_operation_.value; + + using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem; + + using GemmPipeline = typename PipelineTypeTraits< + GemmConfig::Pipeline>::template GemmPipeline; + using GemmEpilogue = ck_tile::CShuffleEpilogue< + ck_tile::CShuffleEpilogueProblem>; + using Kernel = ck_tile::GemmKernel; + auto kargs = Kernel::MakeKernelArgs(args); + + dim3 grids; + if constexpr(Persistent) + { + grids = Kernel::MaxOccupancyGridSize(s); + } + else + { + grids = Kernel::GridSize(args.M, args.N, args.k_batch); + } + constexpr dim3 blocks = Kernel::BlockSize(); + + if(!Kernel::IsSupportedArgument(kargs)) + { + throw ArgumentsNotSupportedException( + "Wrong! Arguments not supported! Skipping gemm!\n"); + } + + if(s.log_level_ > 0) + { + std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n' + << "shape: " << GemmShape::GetName() << '\n' + << "problem: " << GemmPipelineProblem::GetName() << '\n' + << "pipeline: " << GemmPipeline::GetName() << '\n' + << "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" + << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z + << "}" << std::endl; + } + if(s.flush_cache_) + { + std::cout << "Flushing cache..." << std::endl; + static constexpr ck_tile::index_t APackedSize = + std::is_same_v ? 2 : 1; + static constexpr ck_tile::index_t BPackedSize = + std::is_same_v ? 
2 : 1; + + ck_tile::HostTensor a_m(ck_tile::host_tensor_descriptor( + args.M, args.K, args.stride_A, is_row_major(ALayout{}))); + ck_tile::HostTensor b_n(ck_tile::host_tensor_descriptor( + args.K, args.N, args.stride_B, is_row_major(BLayout{}))); + + auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize; + auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize; + + ck_tile::RotatingMemWrapper rotating_mem( + kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer); + rotating_mem.Print(); + + auto run_flush_cache = [&]() { + // flush icache + ck_tile::flush_icache(); + // rotating mem + rotating_mem.Next(); + // clear c mem + if(args.k_batch > 1) + hipGetErrorString(hipMemsetAsync( + args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_)); + }; + ave_time = ck_tile::launch_kernel_preprocess( + s, + run_flush_cache, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + else + { + ave_time = + ck_tile::launch_kernel(s, + ck_tile::make_kernel( + Kernel{}, grids, blocks, 0, kargs)); + } + return ave_time; + }; + + const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) { + if(args.k_batch == 1) + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + else + { + Run(has_hot_loop_, + tail_number_, + ck_tile::integral_constant{}); + } + }; + + BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num); + return ave_time; +} + +template +bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[]) +{ + using Row = ck_tile::tensor_layout::gemm::RowMajor; + using Col = ck_tile::tensor_layout::gemm::ColumnMajor; + + if constexpr(std::is_same_v) + { + if(a_layout == "R" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw 
std::runtime_error("Unsupported memory layout for the input matrices when " + "BPrecType is ck_tile::pk_int4_t!"); + } + } + else + { + if(a_layout == "R" && b_layout == "R") + { + return run_gemm_test_with_layouts( + argc, argv, Row{}, Row{}, Row{}); + } + else if(a_layout == "R" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Row{}, Col{}, Row{}); + } + else if(a_layout == "C" && b_layout == "R") + { + return run_gemm_test_with_layouts( + argc, argv, Col{}, Row{}, Row{}); + } + else if(a_layout == "C" && b_layout == "C") + { + return run_gemm_test_with_layouts( + argc, argv, Col{}, Col{}, Row{}); + } + else + { + throw std::runtime_error("Unsupported memory layout for the input matrices!"); + } + } +} + +template