mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 06:01:23 +00:00
[CK_TILE] Migrate CK Tile examples to Tests to autorun on CI (#2421)
[CK_TILE] Add new ck tile unit test * Add new ck tile unit test smoke-gemm-universal * Add new ck tile unit test smoke-gemm-basic * Add new ck tile unit test topk_softmax * Add new ck tile unit test add_rmsnorm2d_rdquant_fwd
This commit is contained in:
@@ -13,3 +13,7 @@ add_subdirectory(moe_sorting)
|
||||
add_subdirectory(slice_tile)
|
||||
add_subdirectory(batched_transpose)
|
||||
add_subdirectory(smoothquant)
|
||||
add_subdirectory(topk_softmax)
|
||||
add_subdirectory(add_rmsnorm2d_rdquant)
|
||||
# add_subdirectory(layernorm2d)
|
||||
# add_subdirectory(rmsnorm2d)
|
||||
|
||||
26
test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt
Normal file
26
test/ck_tile/add_rmsnorm2d_rdquant/CMakeLists.txt
Normal file
@@ -0,0 +1,26 @@
|
||||
function(create_tile_add_rmsnorm2d_rdquant_fwd SUFFIX)
|
||||
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "test_ck_tile_add_rmsnorm2d_rdquant_fwd_${SUFFIX}")
|
||||
message(DEBUG "adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}")
|
||||
file(GLOB INSTANCE_SRCS instances/*.cpp)
|
||||
add_test_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} add_rmsnorm2d_rdquant_fwd_${SUFFIX}.cpp)
|
||||
target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS})
|
||||
|
||||
set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS})
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
endfunction()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
create_tile_add_rmsnorm2d_rdquant_fwd("fp16")
|
||||
create_tile_add_rmsnorm2d_rdquant_fwd("bf16")
|
||||
else()
|
||||
message(DEBUG "Skipping ck tile add_rmsnorm2d_rdquant_fwd tests for current target")
|
||||
endif()
|
||||
151
test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
Normal file
151
test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp
Normal file
@@ -0,0 +1,151 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename InputDataType, typename QuantizedDataType>
|
||||
struct AddRmsnormRdquantTypeConfig;
|
||||
|
||||
template <>
|
||||
struct AddRmsnormRdquantTypeConfig<ck_tile::half_t, ck_tile::int8_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t, ck_tile::int8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::int8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AddRmsnormRdquantTypeConfig<ck_tile::half_t, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::fp8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct AddRmsnormRdquantTypeConfig<ck_tile::bf16_t, ck_tile::fp8_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YScaleDataType = float;
|
||||
using QYDataType = ck_tile::fp8_t;
|
||||
using ComputeDataType = float;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename InputDataType_,
|
||||
typename QuantizedDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveX_,
|
||||
bool kThreePass_>
|
||||
struct add_rmsnorm2d_rdquant_fwd_traits_
|
||||
{
|
||||
using InputDataType = ck_tile::remove_cvref_t<InputDataType_>;
|
||||
using QuantizedDataType = ck_tile::remove_cvref_t<QuantizedDataType_>;
|
||||
|
||||
static constexpr auto WarpSize = ck_tile::get_warp_size();
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= WarpSize;
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % WarpSize == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / WarpSize;
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (WarpSize / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(WarpSize % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / WarpSize);
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(WarpSize % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % WarpSize == 0);
|
||||
return ThreadPerBlock_N_ / WarpSize;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveX = kSaveX_;
|
||||
static constexpr bool kThreePass = kThreePass_;
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_rdquant_fwd_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct add_rmsnorm2d_rdquant_fwd_traits
|
||||
{
|
||||
std::string input_data_type;
|
||||
std::string quantized_data_type;
|
||||
bool save_x;
|
||||
};
|
||||
|
||||
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits,
|
||||
add_rmsnorm2d_rdquant_fwd_args,
|
||||
const ck_tile::stream_config&);
|
||||
370
test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc
Normal file
370
test/ck_tile/add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.inc
Normal file
@@ -0,0 +1,370 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "add_rmsnorm2d_rdquant_fwd.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename InputDataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::int8_t>()
|
||||
{
|
||||
// due to rounding, int8 quantization might have 1 abs error
|
||||
double rtol = 1;
|
||||
double atol = 1;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_x", "1", "save rms(invrms) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("quant", "int8", "precision")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InputDataType, typename QuantizedDataType, bool SaveX>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string input_data_type = arg_parser.get_str("prec");
|
||||
std::string quantized_data_type = arg_parser.get_str("quant");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= n);
|
||||
|
||||
using TypeConfig = AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>;
|
||||
|
||||
using ADataType = typename TypeConfig::ADataType;
|
||||
using BDataType = typename TypeConfig::BDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YScaleDataType = typename TypeConfig::YScaleDataType;
|
||||
using QYDataType = typename TypeConfig::QYDataType;
|
||||
using ComputeDataType = float;
|
||||
using UnquantYDataType = ck_tile::null_type;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<ADataType> a_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<BDataType> b_host({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
|
||||
ck_tile::HostTensor<XDataType> x_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<XDataType> x_host_dev({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
|
||||
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
|
||||
|
||||
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {stride, 1});
|
||||
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {stride, 1});
|
||||
|
||||
ck_tile::FillUniformDistribution<ADataType>{-.5f, .5f}(a_host);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-.5f, .5f}(b_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
|
||||
ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
a_buf.ToDevice(a_host.data());
|
||||
b_buf.ToDevice(b_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
|
||||
std::cout << "[" << input_data_type << ", " << quantized_data_type << "]"
|
||||
<< " m:" << m << ", n:" << n << ", stride:" << stride << std::flush;
|
||||
|
||||
add_rmsnorm2d_rdquant_fwd_traits traits{input_data_type, quantized_data_type, SaveX};
|
||||
|
||||
add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(),
|
||||
b_buf.GetDeviceBuffer(),
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
x_buf.GetDeviceBuffer(),
|
||||
yscale_buf.GetDeviceBuffer(),
|
||||
qy_buf.GetDeviceBuffer(),
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
stride};
|
||||
|
||||
float ave_time = add_rmsnorm2d_rdquant_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(BDataType) * m * n +
|
||||
sizeof(GammaDataType) * n + sizeof(YScaleDataType) * m +
|
||||
sizeof(QYDataType) * m * n;
|
||||
|
||||
if constexpr(SaveX)
|
||||
num_byte += sizeof(XDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::endl;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
using YDataType = ComputeDataType;
|
||||
using InvRmsDataType = InputDataType;
|
||||
|
||||
// Add
|
||||
{
|
||||
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
ck_tile::reference_binary_elementwise<ADataType, BDataType, XDataType, ComputeDataType>(
|
||||
a_host, b_host, x_host_ref, op);
|
||||
|
||||
if constexpr(SaveX)
|
||||
{
|
||||
x_buf.FromDevice(x_host_dev.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<XDataType>();
|
||||
if(stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(x_host_dev,
|
||||
x_host_ref,
|
||||
std::string("x Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<QYDataType> x_host_dev_row(x_host_dev.begin() + i_r * stride,
|
||||
x_host_dev.begin() + i_r * stride +
|
||||
n);
|
||||
std::vector<QYDataType> x_host_ref_row(x_host_ref.begin() + i_r * stride,
|
||||
x_host_ref.begin() + i_r * stride +
|
||||
n);
|
||||
pass &= ck_tile::check_err(x_host_dev_row,
|
||||
x_host_ref_row,
|
||||
std::string("x[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host({m, n});
|
||||
// Rmsnorm2d
|
||||
{
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n});
|
||||
|
||||
// CAUSION: kernel use ComputeDataType version of x, but we use XDataType here for
|
||||
// simplicity
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(
|
||||
x_host_ref, gamma_host, y_host, invRms_host_ref, unquant_y_host_ref, epsilon);
|
||||
}
|
||||
|
||||
// yscale
|
||||
{
|
||||
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
|
||||
|
||||
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
|
||||
ck_tile::reference_reduce<YDataType, ComputeDataType, YDataType>(
|
||||
y_host, y_rowwise_amax_host, ReduceAmax{});
|
||||
|
||||
auto op = [](const auto& v0) {
|
||||
return v0 /
|
||||
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
|
||||
};
|
||||
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
|
||||
y_rowwise_amax_host, yscale_host_ref, op);
|
||||
|
||||
yscale_buf.FromDevice(yscale_host_dev.mData.data());
|
||||
|
||||
auto [rtol, atol] = get_elimit<YScaleDataType>();
|
||||
pass &= ck_tile::check_err(yscale_host_dev,
|
||||
yscale_host_ref,
|
||||
std::string("yscale Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
// rowwise quantization
|
||||
{
|
||||
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
|
||||
y_host, yscale_host_ref, qy_host_ref);
|
||||
|
||||
qy_buf.FromDevice(qy_host_dev.data());
|
||||
auto [rtol, atol] = get_elimit<QYDataType>();
|
||||
|
||||
if(stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(qy_host_dev,
|
||||
qy_host_ref,
|
||||
std::string("qy Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * stride,
|
||||
qy_host_dev.begin() + i_r * stride + n);
|
||||
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * stride,
|
||||
qy_host_ref.begin() + i_r * stride + n);
|
||||
pass &= ck_tile::check_err(qy_host_dev_row,
|
||||
qy_host_ref_row,
|
||||
std::string("qy[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool dispatch_by_type(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
const std::string input_data_type = arg_parser.get_str("prec");
|
||||
const std::string quantized_data_type = arg_parser.get_str("quant");
|
||||
int save_x = arg_parser.get_int("save_x");
|
||||
if(input_data_type == "fp16" && quantized_data_type == "int8" && save_x)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, true>(arg_parser);
|
||||
}
|
||||
else if(input_data_type == "fp16" && quantized_data_type == "int8" && !save_x)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, false>(arg_parser);
|
||||
}
|
||||
else if(input_data_type == "bf16" && quantized_data_type == "int8" && save_x)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, true>(arg_parser);
|
||||
}
|
||||
else if(input_data_type == "bf16" && quantized_data_type == "int8" && !save_x)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, true>(arg_parser);
|
||||
}
|
||||
else if(input_data_type == "fp16" && quantized_data_type == "fp8" && save_x)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, true>(arg_parser);
|
||||
}
|
||||
else if(input_data_type == "fp16" && quantized_data_type == "fp8" && !save_x)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, false>(arg_parser);
|
||||
}
|
||||
else if(input_data_type == "bf16" && quantized_data_type == "fp8" && save_x)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, true>(arg_parser);
|
||||
}
|
||||
else if(input_data_type == "bf16" && quantized_data_type == "fp8" && !save_x)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, true>(arg_parser);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int run_add_rmsnorm2d_rdquant_combinations(std::string const& data_type)
|
||||
{
|
||||
constexpr size_t PARAM_COUNT = 11;
|
||||
char bufs[PARAM_COUNT][64];
|
||||
char* argv[PARAM_COUNT];
|
||||
|
||||
for(std::size_t i = 0; i < PARAM_COUNT; i++)
|
||||
{
|
||||
argv[i] = bufs[i];
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> params = {
|
||||
{"-m=99", "-n=13"},
|
||||
{"-m=17", "-n=16"},
|
||||
{"-m=1", "-n=100"},
|
||||
{"-m=4", "-n=128"},
|
||||
{"-m=80", "-n=127"},
|
||||
{"-m=22", "-n=255", "-stride=256"},
|
||||
{"-m=7", "-n=599"},
|
||||
{"-m=19", "-n=512"},
|
||||
{"-m=33", "-n=313", "-stride=1000"},
|
||||
{"-m=11", "-n=510"},
|
||||
{"-m=171", "-n=676", "-stride=818"},
|
||||
{"-m=91", "-n=636"},
|
||||
{"-m=12", "-n=768", "-stride=800"},
|
||||
{"-m=100", "-n=766", "-stride=812"},
|
||||
{"-m=31", "-n=1024"},
|
||||
{"-m=64", "-n=1000", "-stride=1004"},
|
||||
{"-m=8", "-n=1501"},
|
||||
{"-m=3", "-n=1826"},
|
||||
{"-m=5", "-n=2040"},
|
||||
{"-m=7", "-n=2734"},
|
||||
{"-m=1", "-n=3182"},
|
||||
{"-m=9", "-n=4096"},
|
||||
{"-m=3", "-n=8192"},
|
||||
{"-m=1", "-n=10547"},
|
||||
{"-m=3", "-n=17134"},
|
||||
};
|
||||
|
||||
bool result = true;
|
||||
std::string pr_i = "-prec=" + data_type;
|
||||
strncpy(bufs[0], "add_rmsnorm2d_rdquant_fwd", 64);
|
||||
strncpy(bufs[1], pr_i.c_str(), 64);
|
||||
for(size_t i = 0; i < params.size(); i++)
|
||||
{
|
||||
for(size_t j = 0; j < params[i].size(); j++)
|
||||
{
|
||||
strncpy(bufs[j + 2], params[i][j].c_str(), 64);
|
||||
}
|
||||
int argc = params[i].size() + 2;
|
||||
|
||||
result = dispatch_by_type(argc, argv) && result;
|
||||
}
|
||||
return result ? 0 : -1;
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd.inc"
|
||||
|
||||
int main() { return run_add_rmsnorm2d_rdquant_combinations("bf16"); }
|
||||
@@ -0,0 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd.inc"
|
||||
|
||||
int main() { return run_add_rmsnorm2d_rdquant_combinations("fp16"); }
|
||||
@@ -0,0 +1,227 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "add_rmsnorm2d_rdquant_fwd.hpp"
|
||||
|
||||
template <typename InputDataType_,
|
||||
typename QuantizedDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveX_,
|
||||
bool kThreePass_>
|
||||
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<InputDataType_,
|
||||
QuantizedDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveX_,
|
||||
kThreePass_>;
|
||||
|
||||
template <typename input_data_type, typename quantized_data_type>
|
||||
float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits t,
|
||||
add_rmsnorm2d_rdquant_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
float r = -1;
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
if(a.n <= 64) {
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 128) {
|
||||
if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 4, 64, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 256) {
|
||||
if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 4, 64, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 4, 64, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 512) {
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 4, 64, 8, true, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 4, 64, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 4, 64, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 4, 64, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 768) {
|
||||
if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 4, 64, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 6, 4, 64, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1,12, 4, 64, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 1024) {
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 2, 128, 8, true, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 2, 128, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 2, 128, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 256, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 1536) {
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 4, 64, 8, true, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 2, 128, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 256, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 6, 1, 256, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 2048) {
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 1, 1, 256, 8, true, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 256, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 256, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 256, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 3072) {
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 128, 8, true, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 256, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 6, 1, 256, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 3, 1, 1024, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 4096) {
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 256, 8, true, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 256, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 1024, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else if(a.n <= 8192) {
|
||||
if(a.n<8192){
|
||||
if(t.save_x){
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, true, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, true, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, true, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, true, true, false>>(s, a);
|
||||
}
|
||||
else{
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, true, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, true, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, true, false, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, true, false, false>>(s, a);
|
||||
}
|
||||
}
|
||||
else{
|
||||
if(t.save_x){
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, false, true, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, false, true, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, false, true, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, false, true, false>>(s, a);
|
||||
}
|
||||
else{
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, false, false, false>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, false, false, false>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, false, false, false>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, false, false, false>>(s, a);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if(a.n > 8192) {
|
||||
if (a.n % 8 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 2, 1, 512, 8, true, true, true>>(s, a);
|
||||
else if (a.n % 4 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 512, 4, true, true, true>>(s, a);
|
||||
else if (a.n % 2 == 0)
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 4, 1, 1024, 2, true, true, true>>(s, a);
|
||||
else
|
||||
r = add_rmsnorm2d_rdquant_fwd_<trait_<input_data_type, quantized_data_type, 1, 8, 1, 1024, 1, true, true, true>>(s, a);
|
||||
}
|
||||
return r;
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t,
|
||||
add_rmsnorm2d_rdquant_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{
|
||||
if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
|
||||
t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::int8_t>(t, a, s);
|
||||
}
|
||||
else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
|
||||
!t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::int8_t>(t, a, s);
|
||||
}
|
||||
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
|
||||
t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::int8_t>(t, a, s);
|
||||
}
|
||||
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("int8") == 0 &&
|
||||
!t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::int8_t>(t, a, s);
|
||||
}
|
||||
else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
|
||||
t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::fp8_t>(t, a, s);
|
||||
}
|
||||
else if(t.input_data_type.compare("fp16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
|
||||
!t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::fp16_t, ck_tile::fp8_t>(t, a, s);
|
||||
}
|
||||
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
|
||||
t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::fp8_t>(t, a, s);
|
||||
}
|
||||
else if(t.input_data_type.compare("bf16") == 0 && t.quantized_data_type.compare("fp8") == 0 &&
|
||||
!t.save_x)
|
||||
{
|
||||
return add_rmsnorm2d_rdquant_fwd_b16_<ck_tile::bf16_t, ck_tile::fp8_t>(t, a, s);
|
||||
}
|
||||
else
|
||||
throw std::runtime_error("Without supported instances!");
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
#if 0
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,18 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,42 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::bf16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,26 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
#if 0
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true , true, false>>(const S&, A);
|
||||
#endif
|
||||
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 2, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 2, 128, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 4, 64, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 2, 128, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,18 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 256, 1, true, true, false>>(const S&, A);
|
||||
|
||||
// clang-format on
|
||||
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 128, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 1, 256, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 256, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 256, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 8, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 1, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,15 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 3, 4, 64, 4, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 6, 4, 64, 2, true , true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 12, 4, 64, 1, true , true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,41 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, false, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, false, true, false>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, false, true, false>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,17 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd x 3p
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::int8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 2, 1, 512, 8, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 512, 4, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 4, 1, 1024, 2, true, true, true>>(const S&, A);
|
||||
template float add_rmsnorm2d_rdquant_fwd_<trait_<ck_tile::fp16_t, ck_tile::fp8_t, 1, 8, 1, 1024, 1, true, true, true>>(const S&, A);
|
||||
// clang-format on
|
||||
@@ -0,0 +1,70 @@
|
||||
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "add_rmsnorm2d_rdquant_fwd.hpp"
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = add_rmsnorm2d_rdquant_fwd_args;
|
||||
|
||||
template <typename InputDataType_,
|
||||
typename QuantizedDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kTwoPass_>
|
||||
using trait_ = add_rmsnorm2d_rdquant_fwd_traits_<InputDataType_,
|
||||
QuantizedDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveInvRms_,
|
||||
kTwoPass_>;
|
||||
|
||||
template <typename Traits_>
|
||||
float add_rmsnorm2d_rdquant_fwd_(const S& s, A a)
|
||||
{
|
||||
using InputDataType = typename Traits_::InputDataType;
|
||||
using QuantizedDataType = typename Traits_::QuantizedDataType;
|
||||
|
||||
using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<
|
||||
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::ADataType,
|
||||
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::BDataType,
|
||||
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::GammaDataType,
|
||||
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::ComputeDataType,
|
||||
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::XDataType,
|
||||
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::YScaleDataType,
|
||||
typename AddRmsnormRdquantTypeConfig<InputDataType, QuantizedDataType>::QYDataType,
|
||||
typename Traits_::Shape,
|
||||
Traits_::kPadN,
|
||||
Traits_::kSaveX,
|
||||
Traits_::kThreePass>;
|
||||
|
||||
using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<PipelineProblem>;
|
||||
using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kThreePass, ThreePassPipeline, OnePassPipeline>;
|
||||
|
||||
using Kernel = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
@@ -20,6 +20,16 @@ if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95")
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_mem PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv3 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_compv4 PRIVATE ${EXAMPLE_GEMM_COMPILE_COMPUTE_V4_OPTIONS})
|
||||
|
||||
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp8 test_gemm_pipeline_universal_fp8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf8 test_gemm_pipeline_universal_bf8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_fp8 test_gemm_pipeline_basic_fp8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf8 test_gemm_pipeline_basic_bf8.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf8 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
else()
|
||||
message(DEBUG "Skipping ck_tile_gemm tests for current target")
|
||||
endif()
|
||||
@@ -27,4 +37,13 @@ endif()
|
||||
if(GPU_TARGETS MATCHES "gfx94" OR GPU_TARGETS MATCHES "gfx95" OR GPU_TARGETS MATCHES "gfx90a")
|
||||
add_gtest_executable(test_ck_tile_gemm_pipeline_persistent test_gemm_pipeline_persistent.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_persistent PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_fp16 test_gemm_pipeline_universal_fp16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_universal_bf16 test_gemm_pipeline_universal_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_universal_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_fp16 test_gemm_pipeline_basic_fp16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_fp16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
add_test_executable(test_ck_tile_gemm_pipeline_basic_bf16 test_gemm_pipeline_basic_bf16.cpp)
|
||||
target_compile_options(test_ck_tile_gemm_pipeline_basic_bf16 PRIVATE ${EXAMPLE_GEMM_COMPILE_OPTIONS})
|
||||
endif()
|
||||
|
||||
5
test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp
Normal file
5
test/ck_tile/gemm/test_gemm_pipeline_basic_bf16.cpp
Normal file
@@ -0,0 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf16"); }
|
||||
5
test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp
Normal file
5
test/ck_tile/gemm/test_gemm_pipeline_basic_bf8.cpp
Normal file
@@ -0,0 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
5
test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp
Normal file
5
test/ck_tile/gemm/test_gemm_pipeline_basic_fp16.cpp
Normal file
@@ -0,0 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp16"); }
|
||||
5
test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp
Normal file
5
test/ck_tile/gemm/test_gemm_pipeline_basic_fp8.cpp
Normal file
@@ -0,0 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#include "test_gemm_pipeline_basic_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
313
test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc
Normal file
313
test/ck_tile/gemm/test_gemm_pipeline_basic_run_test.inc
Normal file
@@ -0,0 +1,313 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <ostream>
|
||||
#include <string>
|
||||
#include <tuple>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
if constexpr(Persistent)
|
||||
std::cout << "WARNING: Ignoring persistent kernel option for basic gemm." << std::endl;
|
||||
// The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
|
||||
constexpr bool kPadM = false;
|
||||
constexpr bool kPadN = false;
|
||||
constexpr bool kPadK = false;
|
||||
|
||||
constexpr int kBlockPerCu = 1;
|
||||
|
||||
// This part comes from the Codegen
|
||||
constexpr ck_tile::index_t M_Tile = 256;
|
||||
constexpr ck_tile::index_t N_Tile = 256;
|
||||
constexpr ck_tile::index_t K_Tile = 64;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp = 2;
|
||||
constexpr ck_tile::index_t N_Warp = 2;
|
||||
constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
constexpr ck_tile::index_t K_Warp_Tile = 16;
|
||||
|
||||
using CodegenGemmShape =
|
||||
ck_tile::TileGemmShape<ck_tile::sequence<M_Tile, N_Tile, K_Tile>,
|
||||
ck_tile::sequence<M_Warp, N_Warp, K_Warp>,
|
||||
ck_tile::sequence<M_Warp_Tile, N_Warp_Tile, K_Warp_Tile>>;
|
||||
|
||||
using TilePartitioner = ck_tile::GemmTile1DPartitioner<CodegenGemmShape>;
|
||||
|
||||
using CodegenGemmTraits =
|
||||
ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
|
||||
|
||||
using CodegenPipelineProblem = ck_tile::
|
||||
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
|
||||
|
||||
using CodegenGemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem>;
|
||||
|
||||
const auto Run = [&](const auto memory_operation_) {
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ck_tile::tuple<>,
|
||||
CLayout,
|
||||
ck_tile::element_wise::PassThrough,
|
||||
CodegenPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
M_Warp,
|
||||
N_Warp,
|
||||
M_Warp_Tile,
|
||||
N_Warp_Tile,
|
||||
K_Warp_Tile,
|
||||
CodegenPipelineProblem::TransposeC,
|
||||
memory_operation>>;
|
||||
|
||||
// ToDo: Will add the codegen part to test different pipeline policies in GEMM.
|
||||
// Now we only use the BlockGemmASmemBSmemCRegV1DefaultPolicy.
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, CodegenGemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
const dim3 grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw ArgumentsNotSupportedException(
|
||||
"Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << CodegenGemmShape::GetName() << '\n'
|
||||
<< "problem: " << CodegenPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << CodegenGemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << "}"
|
||||
<< std::endl;
|
||||
}
|
||||
|
||||
float ave_time = ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
|
||||
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return Run(ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename APrecType, typename BPrecType = APrecType, typename CPrecType = APrecType>
|
||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfigBase, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::bf16_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfigBase::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_test_prec_type<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_combinations(std::string const& data_type)
|
||||
{
|
||||
// Define possible values for each parameter
|
||||
std::vector<std::string> m_values = {"128", "1024"};
|
||||
std::vector<std::string> n_values = {"128", "2048"};
|
||||
std::vector<std::string> k_values = {"64", "128"};
|
||||
std::vector<std::string> prec_values = {data_type};
|
||||
|
||||
// We'll store all our arguments as strings first
|
||||
std::vector<std::string> arg_strings = {"./bin/tile_example_gemm_basic",
|
||||
"", // m placeholder
|
||||
"", // n placeholder
|
||||
"", // k placeholder
|
||||
"-stride_a=0",
|
||||
"-stride_b=0",
|
||||
"-stride_c=0",
|
||||
"", // prec placeholder
|
||||
"-v=2",
|
||||
"-warmup=0",
|
||||
"-repeat=1"};
|
||||
|
||||
// Create an array of const char pointers for argv
|
||||
constexpr size_t ARG_COUNT = 11;
|
||||
constexpr size_t ARG_MAX_LEN = 64;
|
||||
char args[ARG_COUNT][ARG_MAX_LEN];
|
||||
char* argv[ARG_COUNT];
|
||||
|
||||
// Run all combinations
|
||||
bool is_success = true;
|
||||
for(const auto& m : m_values)
|
||||
{
|
||||
arg_strings[1] = "-m=" + m;
|
||||
|
||||
for(const auto& n : n_values)
|
||||
{
|
||||
arg_strings[2] = "-n=" + n;
|
||||
|
||||
for(const auto& k : k_values)
|
||||
{
|
||||
arg_strings[3] = "-k=" + k;
|
||||
|
||||
for(const auto& prec : prec_values)
|
||||
{
|
||||
arg_strings[7] = "-prec=" + prec;
|
||||
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success = run_gemm_test(ARG_COUNT, argv) && is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
|
||||
}
|
||||
458
test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc
Normal file
458
test/ck_tile/gemm/test_gemm_pipeline_smoke_run_test.inc
Normal file
@@ -0,0 +1,458 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <typename Layout>
|
||||
static constexpr inline auto is_row_major(Layout layout_)
|
||||
{
|
||||
return ck_tile::bool_constant<std::is_same_v<ck_tile::remove_cvref_t<decltype(layout_)>,
|
||||
ck_tile::tensor_layout::gemm::RowMajor>>{};
|
||||
}
|
||||
|
||||
template <typename ADataType, typename BDataType, typename AccDataType, typename CDataType>
|
||||
auto calculate_rtol_atol(const ck_tile::index_t K,
|
||||
const ck_tile::index_t kbatch,
|
||||
const float max_accumulated_value)
|
||||
{
|
||||
using ComputeType =
|
||||
std::conditional_t<sizeof(ADataType) < sizeof(BDataType), ADataType, BDataType>;
|
||||
// Calculate thresholds
|
||||
const auto rtol = ck_tile::get_relative_threshold<ComputeType, CDataType, AccDataType>(
|
||||
ck_tile::integer_divide_ceil(K, kbatch));
|
||||
const auto atol = ck_tile::get_absolute_threshold<ComputeType, CDataType, AccDataType>(
|
||||
max_accumulated_value / kbatch, ck_tile::integer_divide_ceil(K, kbatch));
|
||||
// Calculate error due to split_k accumulation
|
||||
const auto rtol_split_k =
|
||||
ck_tile::get_relative_threshold<CDataType, CDataType, CDataType>(kbatch);
|
||||
const auto atol_split_k = ck_tile::get_absolute_threshold<CDataType, CDataType, CDataType>(
|
||||
max_accumulated_value, kbatch);
|
||||
// Use higher threshold
|
||||
return ck_tile::make_tuple(std::max(rtol, rtol_split_k), std::max(atol, atol_split_k));
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename Tensor,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
void permute_tensor_b(Tensor& tensor)
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity>;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmConfig::Scheduler,
|
||||
true,
|
||||
ck_tile::TailNumber::Full>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<GemmConfig::Pipeline>::template GemmPipeline<
|
||||
UniversalGemmProblem>;
|
||||
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
const ck_tile::index_t K1 = GemmPipeline::GetSmemPackB();
|
||||
const ck_tile::index_t K0 = K / K1;
|
||||
|
||||
Tensor tensor_copy = tensor;
|
||||
|
||||
// int K0, N, K1
|
||||
for(int j = 0; j < K0; j++)
|
||||
{
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int jj = 0; jj < K1; jj++)
|
||||
{
|
||||
tensor(j * N * K1 + i * K1 + jj) = tensor_copy(i * K + (j * K1 + jj));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Tensor>
|
||||
void permute_vectors_i4x4_b(Tensor& tensor)
|
||||
{
|
||||
const ck_tile::index_t K = tensor.get_length(0);
|
||||
const ck_tile::index_t N = tensor.get_length(1);
|
||||
// vector pk_i4x4 permute
|
||||
for(int i = 0; i < N; i++)
|
||||
{
|
||||
for(int j = 0; j < K; j += 8)
|
||||
{
|
||||
int8_t input[8];
|
||||
|
||||
for(int k = 0; k < 4; k++)
|
||||
{
|
||||
int8_t i4x2 = tensor(j + k * 2, i).data;
|
||||
input[k * 2 + 0] = (i4x2 >> 4) & 0xf;
|
||||
input[k * 2 + 1] = (i4x2 >> 0) & 0xf;
|
||||
}
|
||||
|
||||
// permute 01234567->20643175
|
||||
{
|
||||
int8_t hi = input[2];
|
||||
int8_t lo = input[0];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 0, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[6];
|
||||
int8_t lo = input[4];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 2, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[3];
|
||||
int8_t lo = input[1];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 4, i) = i4x2;
|
||||
}
|
||||
|
||||
{
|
||||
int8_t hi = input[7];
|
||||
int8_t lo = input[5];
|
||||
int8_t i4x2 = (hi << 4) | lo;
|
||||
|
||||
tensor(j + 6, i) = i4x2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float gemm(const ck_tile::GemmHostArgs<>& args, const ck_tile::stream_config& s);
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
typename CDEElementWise = ck_tile::element_wise::PassThrough>
|
||||
float invoke_gemm(ck_tile::DeviceMem& a_m_k_dev_buf,
|
||||
ck_tile::DeviceMem& b_k_n_dev_buf,
|
||||
ck_tile::DeviceMem& c_m_n_dev_buf,
|
||||
ck_tile::index_t M,
|
||||
ck_tile::index_t N,
|
||||
ck_tile::index_t K,
|
||||
ck_tile::index_t stride_A,
|
||||
ck_tile::index_t stride_B,
|
||||
ck_tile::index_t stride_C,
|
||||
ck_tile::index_t kbatch,
|
||||
int n_warmup,
|
||||
int n_repeat,
|
||||
bool persistent)
|
||||
{
|
||||
ck_tile::GemmHostArgs</*NumDTensor = 0*/> args = {a_m_k_dev_buf.GetDeviceBuffer(),
|
||||
b_k_n_dev_buf.GetDeviceBuffer(),
|
||||
{},
|
||||
c_m_n_dev_buf.GetDeviceBuffer(),
|
||||
kbatch,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
{},
|
||||
stride_C};
|
||||
|
||||
float ave_time;
|
||||
if(persistent)
|
||||
{
|
||||
ave_time = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
true,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time = gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
CLayout,
|
||||
false,
|
||||
CDEElementWise>(
|
||||
args, ck_tile::stream_config{nullptr, true, 1, n_warmup, n_repeat, true, true, 50});
|
||||
}
|
||||
|
||||
std::size_t flop = std::size_t(2) * M * N * K;
|
||||
std::size_t num_byte =
|
||||
sizeof(ADataType) * M * K + sizeof(BDataType) * N * K + sizeof(CDataType) * M * N;
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
|
||||
std::cout << "Run Gemm kernel with M=" << M << " N=" << N << " K=" << K
|
||||
<< " StrideA=" << stride_A << " StrideB=" << stride_B << " StrideC=" << stride_C
|
||||
<< " A_Layout=" << ALayout::name << " B_Layout =" << BLayout::name
|
||||
<< " C_Layout=" << CLayout::name << " A_Type=" << DataTypeTraits<ADataType>::name
|
||||
<< " B_Type=" << DataTypeTraits<BDataType>::name
|
||||
<< " C_Type=" << DataTypeTraits<CDataType>::name
|
||||
<< " StructuredSparsity=" << (GemmConfig::UseStructuredSparsity ? "on" : "off")
|
||||
<< " Persistent=" << (persistent ? "on" : "off") << " : " << ave_time << " ms, "
|
||||
<< tflops << " TFlops, " << gb_per_sec << " GB/s, " << std::endl;
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType = ADataType,
|
||||
typename CDataType = ADataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout>
|
||||
bool run_gemm_test_with_layouts(int argc,
|
||||
char* argv[],
|
||||
const ALayout a_layout = ALayout{},
|
||||
const BLayout b_layout = BLayout{},
|
||||
[[maybe_unused]] const CLayout c_layout = CLayout{})
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
using AccDataType = typename GemmTypeConfig<ADataType, BDataType, CDataType>::AccDataType;
|
||||
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t K = arg_parser.get_int("k");
|
||||
|
||||
ck_tile::index_t stride_A = arg_parser.get_int("stride_a");
|
||||
ck_tile::index_t stride_B = arg_parser.get_int("stride_b");
|
||||
ck_tile::index_t stride_C = arg_parser.get_int("stride_c");
|
||||
|
||||
ck_tile::index_t kbatch = arg_parser.get_int("split_k");
|
||||
int n_warmup = arg_parser.get_int("warmup");
|
||||
int n_repeat = arg_parser.get_int("repeat");
|
||||
ck_tile::index_t init_method = arg_parser.get_int("init");
|
||||
bool persistent = arg_parser.get_int("persistent");
|
||||
|
||||
stride_A = ck_tile::get_default_stride(M, K, stride_A, is_row_major(a_layout));
|
||||
stride_B = ck_tile::get_default_stride(K, N, stride_B, is_row_major(b_layout));
|
||||
stride_C = ck_tile::get_default_stride(M, N, stride_C, is_row_major(CLayout{}));
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m_k(
|
||||
ck_tile::host_tensor_descriptor(M, K, stride_A, is_row_major(a_layout)));
|
||||
ck_tile::HostTensor<BDataType> b_k_n(
|
||||
ck_tile::host_tensor_descriptor(K, N, stride_B, is_row_major(b_layout)));
|
||||
ck_tile::HostTensor<CDataType> c_m_n_dev_result(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
|
||||
if(init_method == 0)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{-5.f, 5.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{-5.f, 5.f}(b_k_n);
|
||||
}
|
||||
else if(init_method == 1)
|
||||
{
|
||||
ck_tile::FillMonotonicSeq<ADataType>{}(a_m_k);
|
||||
ck_tile::FillMonotonicSeq<BDataType>{}(b_k_n);
|
||||
}
|
||||
else if(init_method == 2)
|
||||
{
|
||||
ck_tile::FillUniformDistribution<ADataType>{1.f, 1.f}(a_m_k);
|
||||
ck_tile::FillUniformDistribution<BDataType>{1.f, 1.f}(b_k_n);
|
||||
}
|
||||
else
|
||||
{
|
||||
a_m_k.SetZero();
|
||||
b_k_n.SetZero();
|
||||
}
|
||||
|
||||
if(GemmConfig::UseStructuredSparsity)
|
||||
{
|
||||
ck_tile::AdjustToStructuredSparsity<ADataType>{}(a_m_k);
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem a_m_k_dev_buf(a_m_k.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem b_k_n_dev_buf(b_k_n.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem c_m_n_dev_buf(c_m_n_dev_result.get_element_space_size_in_bytes());
|
||||
|
||||
static_assert(!GemmConfig::PermuteA, "Not implemented");
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Permute vector pk_i4x4 data for device implementation
|
||||
ck_tile::HostTensor<BDataType> b_k_n_dev = b_k_n;
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
permute_tensor_b<GemmConfig,
|
||||
decltype(b_k_n_dev),
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(b_k_n_dev);
|
||||
}
|
||||
permute_vectors_i4x4_b(b_k_n_dev);
|
||||
b_k_n_dev_buf.ToDevice(b_k_n_dev.data());
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(GemmConfig::PermuteB)
|
||||
{
|
||||
std::cout << "Permute for this DataType is not implemented." << std::endl;
|
||||
return false;
|
||||
}
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
a_m_k_dev_buf.ToDevice(a_m_k.data());
|
||||
c_m_n_dev_buf.SetZero();
|
||||
c_m_n_dev_result.SetZero();
|
||||
|
||||
invoke_gemm<GemmConfig,
|
||||
ADataType,
|
||||
BDataType,
|
||||
ck_tile::tuple<>,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ck_tile::tuple<>,
|
||||
CLayout>(a_m_k_dev_buf,
|
||||
b_k_n_dev_buf,
|
||||
c_m_n_dev_buf,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
stride_A,
|
||||
stride_B,
|
||||
stride_C,
|
||||
kbatch,
|
||||
n_warmup,
|
||||
n_repeat,
|
||||
persistent);
|
||||
|
||||
c_m_n_dev_buf.FromDevice(c_m_n_dev_result.data());
|
||||
bool pass = true;
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
ck_tile::HostTensor<CDataType> c_m_n_host_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
c_m_n_host_ref.SetZero();
|
||||
|
||||
ck_tile::reference_gemm<ADataType, BDataType, AccDataType, CDataType>(
|
||||
a_m_k, b_k_n, c_m_n_host_ref);
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_host_ref.mData.begin(), c_m_n_host_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_host_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The CPU verification result is:" << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
else if(arg_parser.get_int("v") == 2)
|
||||
{
|
||||
if constexpr(std::is_same_v<BDataType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
// Restore input for B for gpu reference
|
||||
b_k_n_dev_buf.ToDevice(b_k_n.data());
|
||||
}
|
||||
|
||||
// memory on host to store gpu reference result
|
||||
ck_tile::HostTensor<CDataType> c_m_n_gpu_ref(
|
||||
ck_tile::host_tensor_descriptor(M, N, stride_C, is_row_major(CLayout{})));
|
||||
// memory on device to store gpu reference result
|
||||
ck_tile::DeviceMem c_m_n_gpu_buf_ref(c_m_n_gpu_ref.get_element_space_size_in_bytes());
|
||||
|
||||
c_m_n_gpu_ref.SetZero();
|
||||
c_m_n_gpu_buf_ref.SetZero();
|
||||
|
||||
ADataType* d_A = static_cast<ADataType*>(a_m_k_dev_buf.GetDeviceBuffer());
|
||||
BDataType* d_B = static_cast<BDataType*>(b_k_n_dev_buf.GetDeviceBuffer());
|
||||
CDataType* d_C = static_cast<CDataType*>(c_m_n_gpu_buf_ref.GetDeviceBuffer());
|
||||
|
||||
ck_tile::reference_gemm_gpu<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
ALayout,
|
||||
BLayout,
|
||||
CLayout>(d_A, d_B, d_C, M, N, K, stride_A, stride_B, stride_C);
|
||||
|
||||
c_m_n_gpu_buf_ref.FromDevice(c_m_n_gpu_ref.data());
|
||||
|
||||
const float max_accumulated_value =
|
||||
*std::max_element(c_m_n_gpu_ref.mData.begin(), c_m_n_gpu_ref.mData.end());
|
||||
const auto rtol_atol = calculate_rtol_atol<ADataType, BDataType, AccDataType, CDataType>(
|
||||
K, kbatch, max_accumulated_value);
|
||||
pass = ck_tile::check_err(c_m_n_dev_result,
|
||||
c_m_n_gpu_ref,
|
||||
"Error: Incorrect results!",
|
||||
rtol_atol.at(ck_tile::number<0>{}),
|
||||
rtol_atol.at(ck_tile::number<1>{}));
|
||||
std::cout << "Relative error threshold: " << rtol_atol.at(ck_tile::number<0>{})
|
||||
<< " Absolute error threshold: " << rtol_atol.at(ck_tile::number<1>{})
|
||||
<< std::endl;
|
||||
std::cout << "The GPU verification result is: " << (pass ? "correct" : "fail") << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
414
test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp
Normal file
414
test/ck_tile/gemm/test_gemm_pipeline_smoke_util.hpp
Normal file
@@ -0,0 +1,414 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/epilogue.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V3 1
|
||||
#define CK_TILE_PIPELINE_MEMORY 2
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V4 3
|
||||
#define CK_TILE_PIPELINE_COMPUTE_V5 4
|
||||
|
||||
class ArgumentsNotSupportedException : public std::logic_error
|
||||
{
|
||||
public:
|
||||
explicit ArgumentsNotSupportedException(const std::string& message) : logic_error(message) {}
|
||||
};
|
||||
|
||||
// temporary workaround to get k_warp_tile based on PrecType and gfx950 or not
|
||||
template <typename PrecType, ck_tile::index_t M_Warp_Tile>
|
||||
constexpr ck_tile::index_t get_k_warp_tile()
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, ck_tile::fp8_t> || std::is_same_v<PrecType, ck_tile::bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return 16;
|
||||
else
|
||||
return 32;
|
||||
#endif
|
||||
}
|
||||
|
||||
struct GemmConfigBase
|
||||
{
|
||||
static constexpr bool kPadM = false;
|
||||
static constexpr bool kPadN = false;
|
||||
static constexpr bool kPadK = false;
|
||||
|
||||
static constexpr bool PermuteA = false;
|
||||
static constexpr bool PermuteB = false;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
|
||||
static constexpr int kBlockPerCu = 1;
|
||||
static constexpr ck_tile::index_t TileParitionerGroupNum = 8;
|
||||
static constexpr ck_tile::index_t TileParitionerM01 = 4;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Intrawave;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
static constexpr ck_tile::index_t NumWaveGroups = 1;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryInterwave : public GemmConfigBase
|
||||
{
|
||||
// Memory friendly for Interwave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
static constexpr auto Scheduler = ck_tile::GemmPipelineScheduler::Interwave;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigMemoryIntrawave : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 4;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = sizeof(PrecType) == 2 ? 8 : 16;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_MEMORY;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3 : public GemmConfigBase
|
||||
{
|
||||
// Compute V3 only support Intrawave scheduler
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV3_2 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 16;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V3;
|
||||
|
||||
static constexpr int kBlockPerCu = 2;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4 : public GemmConfigBase
|
||||
{
|
||||
// Compute V4 only support Intrawave scheduler
|
||||
// Using the ping pong reader in the lds level
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV4_1 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 256;
|
||||
static constexpr ck_tile::index_t N_Tile = 256;
|
||||
static constexpr ck_tile::index_t K_Tile = 128 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 2;
|
||||
static constexpr ck_tile::index_t N_Warp = 2;
|
||||
static constexpr ck_tile::index_t K_Warp = 1;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V4;
|
||||
};
|
||||
|
||||
template <typename PrecType>
|
||||
struct GemmConfigComputeV5 : public GemmConfigBase
|
||||
{
|
||||
static constexpr ck_tile::index_t M_Tile = 128;
|
||||
static constexpr ck_tile::index_t N_Tile = 128;
|
||||
static constexpr ck_tile::index_t K_Tile = 64 / sizeof(PrecType);
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp = 1;
|
||||
static constexpr ck_tile::index_t N_Warp = 1;
|
||||
static constexpr ck_tile::index_t K_Warp = 2;
|
||||
|
||||
static constexpr ck_tile::index_t M_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t N_Warp_Tile = 32;
|
||||
static constexpr ck_tile::index_t K_Warp_Tile = get_k_warp_tile<PrecType, M_Warp_Tile>();
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
static constexpr ck_tile::index_t Pipeline = CK_TILE_PIPELINE_COMPUTE_V5;
|
||||
static constexpr ck_tile::index_t NumWaNumWaveGroups = 2;
|
||||
};
|
||||
|
||||
template <typename ADataType, typename BDataType = ADataType, typename CDataType = ADataType>
|
||||
struct GemmTypeConfig;
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::half_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
// ToDo: Add more bias config to support different categories of GEMM.
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf16_t, ck_tile::bf16_t, ck_tile::bf16_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf16_t;
|
||||
using BDataType = ck_tile::bf16_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::bf16_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::fp8_t, ck_tile::fp8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::fp8_t;
|
||||
using BDataType = ck_tile::fp8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::bf8_t, ck_tile::bf8_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::bf8_t;
|
||||
using BDataType = ck_tile::bf8_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::half_t, ck_tile::pk_int4_t, ck_tile::half_t>
|
||||
{
|
||||
using ADataType = ck_tile::half_t;
|
||||
using BDataType = ck_tile::pk_int4_t;
|
||||
using AccDataType = float;
|
||||
using CDataType = ck_tile::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct GemmTypeConfig<ck_tile::int8_t, ck_tile::int8_t, int32_t>
|
||||
{
|
||||
using ADataType = ck_tile::int8_t;
|
||||
using BDataType = ck_tile::int8_t;
|
||||
using AccDataType = int32_t;
|
||||
using CDataType = int32_t;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct DataTypeTraits;
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<float>
|
||||
{
|
||||
static constexpr const char* name = "fp32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<double>
|
||||
{
|
||||
static constexpr const char* name = "fp64";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<int32_t>
|
||||
{
|
||||
static constexpr const char* name = "int32";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::half_t>
|
||||
{
|
||||
static constexpr const char* name = "fp16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf16_t>
|
||||
{
|
||||
static constexpr const char* name = "bf16";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::fp8_t>
|
||||
{
|
||||
static constexpr const char* name = "fp8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::bf8_t>
|
||||
{
|
||||
static constexpr const char* name = "bf8";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::pk_int4_t>
|
||||
{
|
||||
static constexpr const char* name = "pk_int4_t";
|
||||
};
|
||||
|
||||
template <>
|
||||
struct DataTypeTraits<ck_tile::int8_t>
|
||||
{
|
||||
static constexpr const char* name = "int8";
|
||||
};
|
||||
|
||||
template <ck_tile::index_t PipelineId>
|
||||
struct PipelineTypeTraits;
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_MEMORY>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V3>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV3<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V4>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV4<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineTypeTraits<CK_TILE_PIPELINE_COMPUTE_V5>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
template <typename PipelineProblem>
|
||||
using UniversalGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3840", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("k", "2048", "k dimension")
|
||||
.insert("a_layout", "R", "A tensor data layout - Row by default")
|
||||
.insert("b_layout", "C", "B tensor data layout - Column by default")
|
||||
.insert("c_layout", "R", "C tensor data layout - Row by default")
|
||||
.insert("stride_a", "0", "Tensor A stride")
|
||||
.insert("stride_b", "0", "Tensor B stride")
|
||||
.insert("stride_c", "0", "Tensor C stride")
|
||||
.insert("v", "2", "0. No validation, 1. Validation on CPU, 2. Validation on GPU")
|
||||
.insert("prec", "fp16", "data type. fp16/bf16/fp8/bf8")
|
||||
.insert("warmup", "50", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "100", "number of iterations to benchmark the kernel")
|
||||
.insert("timer", "gpu", "gpu:gpu timer, cpu:cpu timer")
|
||||
.insert("split_k", "1", "splitK value")
|
||||
.insert("init", "0", "0:random, 1:linear, 2:constant(1)")
|
||||
.insert("persistent", "0", "0:non-persistent, 1:persistent");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// host API
|
||||
template <typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename CLayout,
|
||||
bool Persistent = false,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s);
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_universal_bf16.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf16"); }
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_universal_bf8.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("bf8"); }
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_universal_fp16.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp16"); }
|
||||
16
test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp
Normal file
16
test/ck_tile/gemm/test_gemm_pipeline_universal_fp8.cpp
Normal file
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstddef>
|
||||
#include <hip/hip_runtime.h>
|
||||
|
||||
#include <cstring>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "test_gemm_pipeline_smoke_util.hpp"
|
||||
#include "test_gemm_pipeline_smoke_run_test.inc"
|
||||
#include "test_gemm_pipeline_universal_run_test.inc"
|
||||
|
||||
int main() { return run_gemm_combinations("fp8"); }
|
||||
393
test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc
Normal file
393
test/ck_tile/gemm/test_gemm_pipeline_universal_run_test.inc
Normal file
@@ -0,0 +1,393 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename AccDataType,
|
||||
typename CDataType,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
bool Persistent,
|
||||
typename CDEElementWise>
|
||||
float gemm(const ck_tile::GemmHostArgs</*NumDTensor = 0*/>& args, const ck_tile::stream_config& s)
|
||||
|
||||
{
|
||||
using GemmShape = ck_tile::TileGemmShape<
|
||||
ck_tile::sequence<GemmConfig::M_Tile, GemmConfig::N_Tile, GemmConfig::K_Tile>,
|
||||
ck_tile::sequence<GemmConfig::M_Warp, GemmConfig::N_Warp, GemmConfig::K_Warp>,
|
||||
ck_tile::
|
||||
sequence<GemmConfig::M_Warp_Tile, GemmConfig::N_Warp_Tile, GemmConfig::K_Warp_Tile>,
|
||||
GemmConfig::PermuteA,
|
||||
GemmConfig::PermuteB>;
|
||||
|
||||
using TilePartitioner =
|
||||
ck_tile::GemmSpatiallyLocalTilePartitioner<GemmShape,
|
||||
GemmConfig::TileParitionerGroupNum,
|
||||
GemmConfig::TileParitionerM01>;
|
||||
|
||||
using Traits = ck_tile::TileGemmTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
|
||||
using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<GemmConfig::kPadM,
|
||||
GemmConfig::kPadN,
|
||||
GemmConfig::kPadK,
|
||||
GemmConfig::DoubleSmemBuffer,
|
||||
ALayout,
|
||||
BLayout,
|
||||
ELayout,
|
||||
GemmConfig::TransposeC,
|
||||
GemmConfig::UseStructuredSparsity,
|
||||
Persistent,
|
||||
GemmConfig::NumWaveGroups>;
|
||||
using GemmPipelineProblem =
|
||||
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
|
||||
|
||||
using BaseGemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template UniversalGemmPipeline<GemmPipelineProblem>;
|
||||
|
||||
const ck_tile::index_t k_grain = args.k_batch * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * GemmConfig::K_Tile;
|
||||
const ck_tile::index_t num_loop = TilePartitioner::GetLoopNum(K_split);
|
||||
const bool has_hot_loop = BaseGemmPipeline::BlockHasHotloop(num_loop);
|
||||
const ck_tile::TailNumber tail_num = BaseGemmPipeline::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
float ave_time{0};
|
||||
|
||||
const auto Run =
|
||||
[&](const auto has_hot_loop_, const auto tail_number_, const auto memory_operation_) {
|
||||
constexpr bool has_hot_loop_v = has_hot_loop_.value;
|
||||
constexpr auto tail_number_v = tail_number_.value;
|
||||
constexpr auto scheduler = GemmConfig::Scheduler;
|
||||
constexpr auto memory_operation = memory_operation_.value;
|
||||
|
||||
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
scheduler,
|
||||
has_hot_loop_v,
|
||||
tail_number_v>;
|
||||
|
||||
using GemmPipeline = typename PipelineTypeTraits<
|
||||
GemmConfig::Pipeline>::template GemmPipeline<UniversalGemmProblem>;
|
||||
using GemmEpilogue = ck_tile::CShuffleEpilogue<
|
||||
ck_tile::CShuffleEpilogueProblem<ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
CDEElementWise,
|
||||
GemmPipelineProblem::kBlockSize,
|
||||
TilePartitioner::MPerBlock,
|
||||
TilePartitioner::NPerBlock,
|
||||
GemmConfig::M_Warp,
|
||||
GemmConfig::N_Warp,
|
||||
GemmConfig::M_Warp_Tile,
|
||||
GemmConfig::N_Warp_Tile,
|
||||
GemmConfig::K_Warp_Tile,
|
||||
UniversalGemmProblem::TransposeC,
|
||||
memory_operation,
|
||||
GemmConfig::NumWaveGroups>>;
|
||||
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
|
||||
auto kargs = Kernel::MakeKernelArgs(args);
|
||||
|
||||
dim3 grids;
|
||||
if constexpr(Persistent)
|
||||
{
|
||||
grids = Kernel::MaxOccupancyGridSize(s);
|
||||
}
|
||||
else
|
||||
{
|
||||
grids = Kernel::GridSize(args.M, args.N, args.k_batch);
|
||||
}
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
|
||||
if(!Kernel::IsSupportedArgument(kargs))
|
||||
{
|
||||
throw ArgumentsNotSupportedException(
|
||||
"Wrong! Arguments not supported! Skipping gemm!\n");
|
||||
}
|
||||
|
||||
if(s.log_level_ > 0)
|
||||
{
|
||||
std::cout << "Launching kernel with args: " << Kernel::GetName() << '\n'
|
||||
<< "shape: " << GemmShape::GetName() << '\n'
|
||||
<< "problem: " << GemmPipelineProblem::GetName() << '\n'
|
||||
<< "pipeline: " << GemmPipeline::GetName() << '\n'
|
||||
<< "grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
|
||||
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
if(s.flush_cache_)
|
||||
{
|
||||
std::cout << "Flushing cache..." << std::endl;
|
||||
static constexpr ck_tile::index_t APackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
static constexpr ck_tile::index_t BPackedSize =
|
||||
std::is_same_v<BDataType, ck_tile::pk_int4_t> ? 2 : 1;
|
||||
|
||||
ck_tile::HostTensor<ADataType> a_m(ck_tile::host_tensor_descriptor(
|
||||
args.M, args.K, args.stride_A, is_row_major(ALayout{})));
|
||||
ck_tile::HostTensor<BDataType> b_n(ck_tile::host_tensor_descriptor(
|
||||
args.K, args.N, args.stride_B, is_row_major(BLayout{})));
|
||||
|
||||
auto size_a_buffer = a_m.get_element_space_size_in_bytes() / APackedSize;
|
||||
auto size_b_buffer = b_n.get_element_space_size_in_bytes() / BPackedSize;
|
||||
|
||||
ck_tile::RotatingMemWrapper<ADataType, BDataType> rotating_mem(
|
||||
kargs.a_ptr, kargs.b_ptr, s.rotating_count_, size_a_buffer, size_b_buffer);
|
||||
rotating_mem.Print();
|
||||
|
||||
auto run_flush_cache = [&]() {
|
||||
// flush icache
|
||||
ck_tile::flush_icache();
|
||||
// rotating mem
|
||||
rotating_mem.Next();
|
||||
// clear c mem
|
||||
if(args.k_batch > 1)
|
||||
hipGetErrorString(hipMemsetAsync(
|
||||
args.e_ptr, 0, args.M * args.N * sizeof(CDataType), s.stream_id_));
|
||||
};
|
||||
ave_time = ck_tile::launch_kernel_preprocess(
|
||||
s,
|
||||
run_flush_cache,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time =
|
||||
ck_tile::launch_kernel(s,
|
||||
ck_tile::make_kernel<blocks.x, GemmConfig::kBlockPerCu>(
|
||||
Kernel{}, grids, blocks, 0, kargs));
|
||||
}
|
||||
return ave_time;
|
||||
};
|
||||
|
||||
const auto RunSplitk = [&](const auto has_hot_loop_, const auto tail_number_) {
|
||||
if(args.k_batch == 1)
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::set>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
Run(has_hot_loop_,
|
||||
tail_number_,
|
||||
ck_tile::integral_constant<ck_tile::memory_operation_enum,
|
||||
ck_tile::memory_operation_enum::atomic_add>{});
|
||||
}
|
||||
};
|
||||
|
||||
BaseGemmPipeline::TailHandler(RunSplitk, has_hot_loop, tail_num);
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
template <typename GemmConfig,
|
||||
typename APrecType,
|
||||
typename BPrecType = APrecType,
|
||||
typename CPrecType = APrecType>
|
||||
bool run_gemm_test_prec_type(std::string a_layout, std::string b_layout, int argc, char* argv[])
|
||||
{
|
||||
using Row = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
if constexpr(std::is_same_v<BPrecType, ck_tile::pk_int4_t>)
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices when "
|
||||
"BPrecType is ck_tile::pk_int4_t!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(a_layout == "R" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "R" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Row{}, Col{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "R")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Row{}, Row{});
|
||||
}
|
||||
else if(a_layout == "C" && b_layout == "C")
|
||||
{
|
||||
return run_gemm_test_with_layouts<GemmConfig, APrecType, BPrecType, CPrecType>(
|
||||
argc, argv, Col{}, Col{}, Row{});
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported memory layout for the input matrices!");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename PreType> typename GemmConfig>
|
||||
bool run_gemm_test(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
std::string a_layout = arg_parser.get_str("a_layout");
|
||||
std::string b_layout = arg_parser.get_str("b_layout");
|
||||
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::half_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf16")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>, ck_tile::bf16_t>(
|
||||
a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "fp8")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::fp8_t>,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::fp8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "bf8")
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::bf8_t>,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::bf8_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else if(data_type == "pk_int4_t")
|
||||
{
|
||||
// TODO: Add support for bhalf_t ADataType
|
||||
if constexpr(GemmConfig<ck_tile::half_t>::Pipeline == CK_TILE_PIPELINE_COMPUTE_V3)
|
||||
{
|
||||
return run_gemm_test_prec_type<GemmConfig<ck_tile::half_t>,
|
||||
ck_tile::half_t,
|
||||
ck_tile::pk_int4_t,
|
||||
ck_tile::half_t>(a_layout, b_layout, argc, argv);
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported pipeline for this operation !!!");
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("Unsupported data type for this operation !!!");
|
||||
}
|
||||
}
|
||||
|
||||
int run_gemm_combinations(std::string const& data_type)
|
||||
{
|
||||
// Define possible values for each parameter
|
||||
std::vector<std::string> m_values = {"512", "1024"};
|
||||
std::vector<std::string> n_values = {"512", "2048"};
|
||||
std::vector<std::string> k_values = {"512", "1024"};
|
||||
std::vector<std::string> prec_values = {data_type};
|
||||
|
||||
// We'll store all our arguments as strings first
|
||||
std::vector<std::string> arg_strings = {"./bin/tile_example_gemm_universal",
|
||||
"", // m placeholder
|
||||
"", // n placeholder
|
||||
"", // k placeholder
|
||||
"-stride_a=0",
|
||||
"-stride_b=0",
|
||||
"-stride_c=0",
|
||||
"", // prec placeholder
|
||||
"-v=2",
|
||||
"-warmup=0",
|
||||
"-repeat=1"};
|
||||
|
||||
// Create an array of const char pointers for argv
|
||||
constexpr size_t ARG_COUNT = 11;
|
||||
constexpr size_t ARG_MAX_LEN = 64;
|
||||
char args[ARG_COUNT][ARG_MAX_LEN];
|
||||
char* argv[ARG_COUNT];
|
||||
|
||||
// Run all combinations
|
||||
bool is_success = true;
|
||||
for(const auto& m : m_values)
|
||||
{
|
||||
arg_strings[1] = "-m=" + m;
|
||||
|
||||
for(const auto& n : n_values)
|
||||
{
|
||||
arg_strings[2] = "-n=" + n;
|
||||
|
||||
for(const auto& k : k_values)
|
||||
{
|
||||
arg_strings[3] = "-k=" + k;
|
||||
|
||||
for(const auto& prec : prec_values)
|
||||
{
|
||||
arg_strings[7] = "-prec=" + prec;
|
||||
|
||||
// Set up the argv array with pointers to the string data
|
||||
for(size_t i = 0; i < ARG_COUNT; i++)
|
||||
{
|
||||
strncpy(args[i], arg_strings[i].c_str(), ARG_MAX_LEN);
|
||||
argv[i] = args[i];
|
||||
}
|
||||
|
||||
std::cout << "Arguments received: ";
|
||||
for(size_t i = 1; i < ARG_COUNT; ++i)
|
||||
{
|
||||
std::cout << argv[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
// Call the function with the current configuration
|
||||
try
|
||||
{
|
||||
is_success =
|
||||
run_gemm_test<GemmConfigComputeV3>(ARG_COUNT, argv) && is_success;
|
||||
}
|
||||
catch(const ArgumentsNotSupportedException& e)
|
||||
{
|
||||
std::cerr << "Caught ArgumentsNotSupportedException: " << e.what() << '\n';
|
||||
// ArgumentsNotSupportedException is not an error. Do not change is_success
|
||||
}
|
||||
catch(const std::runtime_error& e)
|
||||
{
|
||||
std::cerr << "Caught runtime error: " << e.what() << '\n';
|
||||
is_success = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return is_success ? EXIT_SUCCESS : EXIT_FAILURE;
|
||||
}
|
||||
53
test/ck_tile/layernorm2d/CMakeLists.txt
Normal file
53
test/ck_tile/layernorm2d/CMakeLists.txt
Normal file
@@ -0,0 +1,53 @@
|
||||
function(create_tile_layernorm2d_fwd SUFFIX)
|
||||
set(TEST_CK_TILE_LAYERNORM2D_FWD "test_ck_tile_layernorm2d_fwd_${SUFFIX}")
|
||||
|
||||
message(DEBUG "adding example ${TEST_CK_TILE_LAYERNORM2D_FWD}")
|
||||
add_test_executable(${TEST_CK_TILE_LAYERNORM2D_FWD} layernorm2d_fwd_${SUFFIX}.cpp)
|
||||
target_include_directories(${TEST_CK_TILE_LAYERNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TEST_CK_TILE_LAYERNORM2D_FWD} PRIVATE ${LAYERNORM2D_FWD_GEN_BLOBS})
|
||||
|
||||
set(TEST_CK_TILE_LAYERNORM2D_FWD_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TEST_CK_TILE_LAYERNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
|
||||
|
||||
target_compile_options(${TEST_CK_TILE_LAYERNORM2D_FWD} PRIVATE ${TEST_CK_TILE_LAYERNORM2D_FWD_COMPILE_OPTIONS})
|
||||
endfunction()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
set(LAYERNORM2D_FWD_KNOWN_APIS "fwd;bwd")
|
||||
set(LAYERNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${LAYERNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(LAYERNORM2D_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(LAYERNORM2D_FWD_ENABLE_APIS ${LAYERNORM2D_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/layernorm2d_fwd_blobs.txt LAYERNORM2D_FWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${LAYERNORM2D_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${LAYERNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
|
||||
)
|
||||
|
||||
create_tile_layernorm2d_fwd("fp16")
|
||||
create_tile_layernorm2d_fwd("bf16")
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
else()
|
||||
message(DEBUG "Skipping ck tile add_rmsnorm2d_rdquant_fwd tests for current target")
|
||||
endif()
|
||||
730
test/ck_tile/layernorm2d/generate.py
Normal file
730
test/ck_tile/layernorm2d/generate.py
Normal file
@@ -0,0 +1,730 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
def get_if_str(idx, total, lase_else = True):
|
||||
if idx == 0:
|
||||
return 'if'
|
||||
elif idx < total - 1:
|
||||
return 'else if'
|
||||
else:
|
||||
if lase_else:
|
||||
return 'else'
|
||||
else:
|
||||
return 'else if'
|
||||
|
||||
XBIAS_ENUM_STR_MAP = [
|
||||
'no',
|
||||
'xbias'] # pre-norm add bias
|
||||
|
||||
FUSED_ADD_ENUM_STR_MAP = [
|
||||
'no',
|
||||
'pras', # pre-norm
|
||||
'pra' ] # post-norm
|
||||
|
||||
FUSED_FUSED_SWEEP_STR_MAP = [
|
||||
'no',
|
||||
'dquant' ]
|
||||
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::fp16_t',
|
||||
'bf16' : 'ck_tile::bf16_t',
|
||||
'int8' : 'ck_tile::int8_t',
|
||||
'fp8' : 'ck_tile::fp8_t'}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
if b_:
|
||||
return 'true'
|
||||
else:
|
||||
return 'false'
|
||||
|
||||
class layernorm_fwd_codegen:
|
||||
API_TRAITS_DEFINE = """
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kXbias_ = 0,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
struct layernorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveMeanInvStd = kSaveMeanInvStd_;
|
||||
static constexpr bool kFastFDiv = kFastFDiv_;
|
||||
static constexpr bool kWelford = kWelford_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kXbias = kXbias_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveMeanInvStd_,
|
||||
bool kFastFDiv_,
|
||||
bool kWelford_,
|
||||
bool kTwoPass_,
|
||||
int kXbias_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
using traits_ = layernorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
SmoothScaleDataType_,
|
||||
YScaleDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveMeanInvStd_,
|
||||
kFastFDiv_,
|
||||
kWelford_,
|
||||
kTwoPass_,
|
||||
kXbias_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
"""
|
||||
API_COMMON_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = layernorm2d_fwd_args;
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
template <typename Traits_>
|
||||
float layernorm2d_fwd_(const S& s, A a)
|
||||
{{
|
||||
using XDataType = typename Traits_::XDataType;
|
||||
using YDataType = typename Traits_::YDataType;
|
||||
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
|
||||
using YScaleDataType = typename Traits_::YScaleDataType;
|
||||
using ComputeDataType = typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
|
||||
|
||||
using PipelineTraits = ck_tile::Layernorm2dFwdTraits<Traits_::kPadN,
|
||||
Traits_::kSaveMeanInvStd,
|
||||
Traits_::kFastFDiv,
|
||||
Traits_::kWelford,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Layernorm2dXBiasEnum>(Traits_::kXbias),
|
||||
static_cast<ck_tile::Layernorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Layernorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
using PipelineProblem = ck_tile::Layernorm2dFwdPipelineProblem<
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XBiasDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::BetaDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::MeanDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvStdDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
|
||||
typename LayerNormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
|
||||
typename Traits_::Shape,
|
||||
PipelineTraits>;
|
||||
|
||||
using OnePassPipeline = ck_tile::Layernorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Layernorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, true>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
|
||||
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
|
||||
static constexpr bool UseRawStore = sizeof(YDataType) == 4;
|
||||
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
||||
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, UseRawStore, true/*max3*/>>;
|
||||
|
||||
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
||||
|
||||
using Epilogue = std::conditional_t<Traits_::kFusedQuant == 1, DynamicQuantEpilogue, Default2DEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::Layernorm2dFwd<Pipeline, Epilogue>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float layernorm2d_fwd_(const ck_tile::stream_config& s, layernorm2d_fwd_args a);
|
||||
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits t,
|
||||
layernorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
API_PER_DTYPE=""" {F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
|
||||
{F_per_n_case}
|
||||
}}
|
||||
"""
|
||||
API_PER_N_CASE=""" {F_if} {F_N_COND} {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
API_INNER_CASE=""" {F_if} {F_VEC_COND}
|
||||
r={F_instance_func}(s, a);
|
||||
"""
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd_api_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// prec_i prec_o prec_sy rm rn tm tn vn pd mv rpcf welford 2p xbias add sweep
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, working_path, kernel_filter):
|
||||
self.working_path = working_path
|
||||
self.kernel_filter = kernel_filter
|
||||
|
||||
class k_xbias_enum(IntEnum):
|
||||
F_NO_XBIAS = 0
|
||||
F_ADD_XBIAS = 1
|
||||
|
||||
class k_fuesd_add_enum(IntEnum):
|
||||
F_NO_ADD = 0
|
||||
F_PRE_ADD = 1
|
||||
F_PRE_ADD_STORE_RESIDUAL = 2
|
||||
|
||||
class k_fused_sweep_enum(IntEnum):
|
||||
F_NO_SWEEP = 0
|
||||
F_RENORM = 1
|
||||
F_DYNAMIC_QUANT = 2
|
||||
|
||||
@dataclass
|
||||
class k_traits:
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd : bool
|
||||
F_kTwoPass : bool
|
||||
F_kXbias : Any #: layernorm_fwd_codegen.k_bias_enum
|
||||
F_kFusedAdd : Any #: layernorm_fwd_codegen.k_fuesd_add_enum
|
||||
F_kFusedQuant : Any #: layernorm_fwd_codegen.k_fused_sweep_enum
|
||||
|
||||
@dataclass
|
||||
class k_shape:
|
||||
F_BlockTile : List[int]
|
||||
F_WarpPerBlock : List[int]
|
||||
F_WarpTile : List[int]
|
||||
F_Vector_ : List[int]
|
||||
@property
|
||||
def F_BlockSize(self) -> int:
|
||||
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
|
||||
|
||||
@dataclass
|
||||
class k_problem:
|
||||
F_XDataType : str
|
||||
F_XBiasDataType : str
|
||||
F_GammaDataType : str
|
||||
F_BetaDataType : str
|
||||
F_ComputeDataType : str
|
||||
F_YDataType : str
|
||||
F_MeanDataType : str
|
||||
F_InvStdDataType : str
|
||||
F_BlockShape : str
|
||||
F_Traits : Any #k_traits
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_one_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_two_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue_problem:
|
||||
F_AccDataType : str
|
||||
F_ODataType : str
|
||||
F_kPadM : bool
|
||||
F_kPadN : bool
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue:
|
||||
F_problem : Any
|
||||
|
||||
@dataclass
|
||||
class k_kernel:
|
||||
F_pipeline : Any
|
||||
F_epilogue : Any
|
||||
|
||||
@dataclass
|
||||
class h_traits:
|
||||
F_XDataType : str
|
||||
F_YDataType : str
|
||||
F_SmoothScaleDataType : str
|
||||
F_YScaleDataType : str
|
||||
F_Repeat_M : int
|
||||
F_Repeat_N : int
|
||||
F_ThreadPerBlock_M : int
|
||||
F_ThreadPerBlock_N : int
|
||||
F_Vector_N : int
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd_ : bool
|
||||
F_kFastFDiv_ : bool
|
||||
F_kWelford_ : bool
|
||||
F_kTwoPass_ : bool
|
||||
F_kXbias_ : int
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveMeanInvStd_):5}, {BOOL_MAP(self.F_kFastFDiv_):5}, {BOOL_MAP(self.F_kWelford_):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass_):5}, {self.F_kXbias:4}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@property
|
||||
def call_name(self) -> str:
|
||||
return f'layernorm2d_fwd_<traits_<{self.trait_name}>>'
|
||||
|
||||
# string when define this kernel
|
||||
@property
|
||||
def def_name(self) -> str:
|
||||
return f'template float layernorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
|
||||
|
||||
# this class hold kernel under same source file
|
||||
@dataclass
|
||||
class h_instance:
|
||||
F_DataTypePair : str
|
||||
F_N : str
|
||||
F_xbias : int
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
prec_i, prec_o = self.F_DataTypePair.split(',')
|
||||
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
|
||||
nnn = f'layernorm2d_fwd_{dtype_str}_n{self.F_N}'
|
||||
if self.F_xbias != 0:
|
||||
nnn = nnn + '_' + XBIAS_ENUM_STR_MAP[self.F_xbias]
|
||||
if self.F_add != 0:
|
||||
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
||||
if self.F_sweep != 0:
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
return nnn
|
||||
|
||||
@property
|
||||
def instance_name(self) ->str:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def content(self) ->str:
|
||||
instance_defs = ''
|
||||
for ins in self.instance_list:
|
||||
instance_defs += ins.def_name + '\n'
|
||||
return layernorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
|
||||
|
||||
@property
|
||||
def name_api(self) -> str:
|
||||
return 'layernorm2d_fwd_api'
|
||||
|
||||
@property
|
||||
def name_common_header(self) -> str:
|
||||
return 'layernorm2d_fwd_api_common'
|
||||
|
||||
def content_api(self, args) -> str:
|
||||
# 1 sort based on dtype
|
||||
t_dtype_dict = dict()
|
||||
blobs = self.get_blobs(args)
|
||||
for blob in blobs:
|
||||
if blob.F_DataTypePair not in t_dtype_dict:
|
||||
t_dtype_dict[blob.F_DataTypePair] = {}
|
||||
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
|
||||
|
||||
d_str = ''
|
||||
for i_d, dtype_ in enumerate(t_dtype_dict):
|
||||
blob_per_t = t_dtype_dict[dtype_]
|
||||
n_str = ''
|
||||
for i_n, n_ in enumerate(blob_per_t):
|
||||
blob_per_n = blob_per_t[n_]
|
||||
inner_str = ""
|
||||
for i_b, b_ in enumerate(blob_per_n):
|
||||
# generate single kernel instance file
|
||||
#vec_str = ""
|
||||
for i_ins, ins in enumerate(b_.instance_list):
|
||||
idx_in_n = i_b * len(b_.instance_list) + i_ins
|
||||
len_in_n = len(blob_per_n) * len(b_.instance_list)
|
||||
# _if = 'if' if i_ins == 0 else 'else if'
|
||||
if ins.F_kFusedQuant == 0:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
|
||||
elif ins.F_kFusedQuant == 1:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType)
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\")'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType)
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.xbias == {f_xbias}) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_xbias = ins.F_kXbias, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
#inner_str = inner_str + vec_str
|
||||
n_cnd = f'(a.n <= {n_})' if isinstance(n_, int) else ''
|
||||
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t), not isinstance(n_, int)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
|
||||
prec_i, prec_o = dtype_.split(',')
|
||||
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
|
||||
|
||||
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
|
||||
return api_base
|
||||
|
||||
@property
|
||||
def content_common_header(self) -> str:
|
||||
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
|
||||
|
||||
def get_blobs(self, args):
|
||||
h_traits = layernorm_fwd_codegen.h_traits
|
||||
h_instance = layernorm_fwd_codegen.h_instance
|
||||
|
||||
dynamic_quant_out_dtype = ['int8', 'fp8']
|
||||
# some predefined support range
|
||||
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
||||
scale_list = [('fp32,fp32')]
|
||||
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
||||
('fp16,int8'), ('bf16,int8'),
|
||||
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 or fp8 out
|
||||
types_8bit = ('int8', 'fp8')
|
||||
types_16bit = ('int16', 'fp16', 'bf16')
|
||||
#fused_add_list = [0, 1, 2]
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused dynamic quant
|
||||
xbias_list = [0, 1]
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1] # NOTE: only single pass can use fused dynamic quant
|
||||
# rm rn tm tn vn pd mv fdiv welford 2p xbias add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 8, 8, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 16, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 1, 1, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 4, 64, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 4, 64, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 2, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 2, 128, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 4, 64, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 2, 128, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1, 256, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 128, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1, 256, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 1, 2, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 2, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 3, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 6, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 8, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 512, 4, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 2, True, False, True, True, False, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 8, 1,1024, 1, True, False, True, True, False, 0, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 1, 1, 1,1024, 8, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1, 256, 4, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 12, 1, 256, 2, True, False, True, True, True, 0, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 1, 4, 1,1024, 1, True, False, True, True, True, 0, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, xbias, fused_add, fused_quant in itertools.product(dtype_list, scale_list, xbias_list, fused_add_list, fused_sweep_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1:
|
||||
continue # skip non dynamic quant case
|
||||
if fused_quant == 1 and hs_key == 'big':
|
||||
continue
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_kXbias = xbias
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
# disable welford update for 8bit and 16 bit smallN
|
||||
if not h_.F_kTwoPass_:
|
||||
#disable 16 bit when set args disable_16b_welford
|
||||
if args.disable_16b_welford and prec_i in types_16bit:
|
||||
h_.F_kWelford_ = False
|
||||
#disable 8bit by default
|
||||
elif prec_i in types_8bit or prec_o in types_8bit:
|
||||
h_.F_kWelford_ = False
|
||||
#disable 16bit small N
|
||||
elif prec_i in types_16bit and hs_key == '64':
|
||||
h_.F_kWelford_ = False
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, xbias, fused_add, fused_quant, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self, args) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
list_p = w_p / 'layernorm2d_fwd_blobs.txt'
|
||||
blobs = self.get_blobs(args)
|
||||
with list_p.open('w') as list_f:
|
||||
# api related file
|
||||
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
||||
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
|
||||
# kernel instance file
|
||||
for b in blobs:
|
||||
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
|
||||
|
||||
def gen_blobs(self, args) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
w_str = self.content_api(args)
|
||||
(w_p / (self.name_api + ".cpp")).write_text(w_str)
|
||||
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
|
||||
blobs = self.get_blobs(args)
|
||||
for b in blobs:
|
||||
(w_p / (b.name + ".cpp")).write_text(b.content)
|
||||
|
||||
def list_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
layernorm_fwd_codegen(args.working_path, args.filter).list_blobs(args)
|
||||
|
||||
|
||||
def gen_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
layernorm_fwd_codegen(args.working_path, args.filter).gen_blobs(args)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK layernorm kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--api",
|
||||
default='fwd[all]',
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma."
|
||||
)
|
||||
|
||||
# the directory for list_blobs/gen_blobs to write files into
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--working_path",
|
||||
default="./",
|
||||
required=False,
|
||||
help="the path where all the blobs are going to be generated"
|
||||
)
|
||||
|
||||
# this script have 2 modes
|
||||
# 1) list_blobs mode, will generate a txt file with all the files going to be generated.
|
||||
# this is useful in build system like cmake to construct source code dependency, by
|
||||
# reading the content out of this file
|
||||
# 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework
|
||||
# like FA, only need to use this mode
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
action='store_true',
|
||||
help="list all the kernels to a file, "
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--gen_blobs",
|
||||
action='store_true',
|
||||
help="generate all kernels into different tile"
|
||||
)
|
||||
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--traits",
|
||||
default="all",
|
||||
required=False,
|
||||
help="enable/disable some feature. default generate all"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--disable_16b_welford",
|
||||
default=False,
|
||||
required=False,
|
||||
help="enable/disable welford for 16bit datatype n > 64"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# print(f'{args.list_blobs}-{args.gen_blobs}')
|
||||
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
|
||||
print('gen_blobs/list_blobs must specify only one option')
|
||||
sys.exit()
|
||||
|
||||
p = Path(args.working_path)
|
||||
if not p.exists():
|
||||
p.mkdir()
|
||||
|
||||
if args.list_blobs:
|
||||
list_blobs(args)
|
||||
else:
|
||||
gen_blobs(args)
|
||||
70
test/ck_tile/layernorm2d/layernorm2d_fwd.hpp
Normal file
70
test/ck_tile/layernorm2d/layernorm2d_fwd.hpp
Normal file
@@ -0,0 +1,70 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/layernorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename InType,
|
||||
typename OutType,
|
||||
typename SmoothSScaleDataType_,
|
||||
typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig;
|
||||
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::half_t;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using BetaDataType = ck_tile::half_t;
|
||||
using MeanDataType = ck_tile::half_t;
|
||||
using InvStdDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct LayerNormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using XBiasDataType = ck_tile::bf16_t;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using BetaDataType = ck_tile::bf16_t;
|
||||
using MeanDataType = ck_tile::bf16_t;
|
||||
using InvStdDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct layernorm2d_fwd_args : public ck_tile::Layernorm2dFwdHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct layernorm2d_fwd_traits
|
||||
{
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_o; // output precision
|
||||
|
||||
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
|
||||
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
|
||||
// can set arbitrary(will skip check)
|
||||
std::string prec_sm; // x-scale, used for [1*N] input smooth quant
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_mean_var; //
|
||||
int xbias; // 0:no-bias, 1:add bias
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float layernorm2d_fwd(layernorm2d_fwd_traits, layernorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
566
test/ck_tile/layernorm2d/layernorm2d_fwd.inc
Normal file
566
test/ck_tile/layernorm2d/layernorm2d_fwd.inc
Normal file
@@ -0,0 +1,566 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "layernorm2d_fwd.hpp"
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::int8_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1.0;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
|
||||
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
|
||||
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
|
||||
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_mv", "0", "save mean/variance(invstd) or not. set to 1 in training case")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "fp16", "input precision")
|
||||
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
|
||||
.insert("prec_sm",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1")
|
||||
.insert("prec_sy",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1 or 2")
|
||||
.insert("xbias", "0", "add bias, 0:no add, 1:add bias before fadd")
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename SmoothScaleDataType,
|
||||
typename YScaleDataType,
|
||||
bool SaveMeanVar>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
x_stride = n;
|
||||
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
|
||||
if(xr_stride < 0)
|
||||
xr_stride = n;
|
||||
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
|
||||
if(y_stride < 0)
|
||||
y_stride = n;
|
||||
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
|
||||
if(yr_stride < 0)
|
||||
yr_stride = n;
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
int xbias = arg_parser.get_int("xbias");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
if(fused_quant == 1 && prec_o != "int8" && prec_o != "fp8")
|
||||
{
|
||||
std::cout
|
||||
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
assert(x_stride >= n);
|
||||
|
||||
using TypeConfig =
|
||||
LayerNormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using XBiasDataType = typename TypeConfig::XBiasDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using BetaDataType = typename TypeConfig::BetaDataType;
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
using MeanDataType =
|
||||
std::conditional_t<SaveMeanVar, typename TypeConfig::MeanDataType, ck_tile::null_type>;
|
||||
using InvStdDataType =
|
||||
std::conditional_t<SaveMeanVar, typename TypeConfig::InvStdDataType, ck_tile::null_type>;
|
||||
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<XBiasDataType> x_bias_host({n});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<BetaDataType> beta_host({n});
|
||||
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<MeanDataType> mean_host_ref({m});
|
||||
ck_tile::HostTensor<InvStdDataType> invStd_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
|
||||
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
|
||||
ck_tile::FillUniformDistribution<XBiasDataType>{-.5f, .5f}(x_bias_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
ck_tile::FillUniformDistribution<BetaDataType>{-.5f, .5f}(beta_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_bias_buf(x_bias_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem beta_buf(beta_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
x_bias_buf.ToDevice(x_bias_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
beta_buf.ToDevice(beta_host.data());
|
||||
x_residual_buf.ToDevice(x_residual_host.data());
|
||||
sm_scale_buf.ToDevice(sm_scale_host.data());
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_o)
|
||||
{
|
||||
base_str += "|" + prec_o;
|
||||
}
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
base_str += std::string("(") + prec_sy + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
layernorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sm, prec_sy, SaveMeanVar, xbias, fused_add, fused_quant};
|
||||
|
||||
layernorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
x_bias_buf.GetDeviceBuffer(),
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
beta_buf.GetDeviceBuffer(),
|
||||
|
||||
y_buf.GetDeviceBuffer(),
|
||||
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
nullptr, // p_mean, unsupported yet
|
||||
nullptr, // p_invStd, unsupported yet
|
||||
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
x_stride, // x row_stride
|
||||
xr_stride, // x residule row stride
|
||||
y_stride, // y row stride
|
||||
yr_stride}; // y residule row stride
|
||||
|
||||
float ave_time = layernorm2d_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
if(ave_time < 0)
|
||||
{
|
||||
std::cout << " not supported!" << std::endl << std::flush;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XBiasDataType) * n +
|
||||
sizeof(GammaDataType) * n + sizeof(BetaDataType) * n +
|
||||
sizeof(YDataType) * m * n;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
if(xbias != 0)
|
||||
{
|
||||
// add bias before fadd
|
||||
int M = x_host.mDesc.get_lengths()[0];
|
||||
int N = x_host.mDesc.get_lengths()[1];
|
||||
for(int idx_m = 0; idx_m < M; ++idx_m)
|
||||
{
|
||||
for(int idx_n = 0; idx_n < N; ++idx_n)
|
||||
{
|
||||
x_host(idx_m, idx_n) = ck_tile::type_convert<XDataType>(
|
||||
ck_tile::type_convert<ComputeDataType>(x_host(idx_m, idx_n)) +
|
||||
ck_tile::type_convert<ComputeDataType>(x_bias_host(idx_n)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(fused_add != 0)
|
||||
{
|
||||
// fused pre_add/pre_add_store
|
||||
// TODO we accumulate directly to x_host for simplcity here...
|
||||
|
||||
std::transform(x_host.mData.cbegin(),
|
||||
x_host.mData.cend(),
|
||||
x_residual_host.mData.cbegin(),
|
||||
x_host.mData.begin(),
|
||||
[](auto x_, auto r_) {
|
||||
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
|
||||
ck_tile::type_convert<ComputeDataType>(r_);
|
||||
return ck_tile::type_convert<XDataType>(o_);
|
||||
});
|
||||
}
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType>(
|
||||
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
|
||||
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
|
||||
int N_ = acc_.mDesc.get_lengths()[1];
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
// input smooth outlier
|
||||
acc_(m_, n_) = acc_(m_, n_) *
|
||||
ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
|
||||
}
|
||||
}
|
||||
ComputeDataType absmax = static_cast<ComputeDataType>(0);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
const auto a = ck_tile::abs(acc_(m_, n_));
|
||||
absmax = a > absmax ? a : absmax;
|
||||
}
|
||||
// printf("cpu:absmax:%f\n", absmax);
|
||||
constexpr ComputeDataType kMaxY =
|
||||
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
|
||||
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
|
||||
: 0.0;
|
||||
ComputeDataType y_scale = absmax / kMaxY;
|
||||
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
|
||||
}
|
||||
};
|
||||
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType>(x_host,
|
||||
gamma_host,
|
||||
beta_host,
|
||||
y_host_ref,
|
||||
mean_host_ref,
|
||||
invStd_host_ref,
|
||||
epsilon,
|
||||
dquant_functor);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_layernorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
BetaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
MeanDataType,
|
||||
InvStdDataType>(
|
||||
x_host, gamma_host, beta_host, y_host_ref, mean_host_ref, invStd_host_ref, epsilon);
|
||||
}
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
|
||||
if(fused_add == 1)
|
||||
{
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<OutDataType>();
|
||||
|
||||
if(x_stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol);
|
||||
if(fused_add == 1)
|
||||
{
|
||||
pass &= ck_tile::check_err(y_residual_host_dev,
|
||||
x_host,
|
||||
std::string("ADD Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
|
||||
y_host_dev.begin() + i_r * y_stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
|
||||
y_host_ref.begin() + i_r * y_stride + n);
|
||||
pass &= ck_tile::check_err(y_host_dev_row,
|
||||
y_host_ref_row,
|
||||
std::string("OUT[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
if(fused_add == 1)
|
||||
{
|
||||
std::vector<YResidualDataType> y_residual_host_dev_row(
|
||||
y_residual_host_dev.begin() + i_r * yr_stride,
|
||||
y_residual_host_dev.begin() + i_r * yr_stride + n);
|
||||
std::vector<YResidualDataType> y_residual_host_ref_row(
|
||||
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
|
||||
pass &= ck_tile::check_err(y_residual_host_dev_row,
|
||||
y_residual_host_ref_row,
|
||||
std::string("ADD[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
y_scale_buf.FromDevice(y_scale_host_dev.data());
|
||||
pass &= ck_tile::check_err(y_scale_host_dev,
|
||||
y_scale_host_ref,
|
||||
std::string("SCALE Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool dispatch_by_type(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
int save_mv = arg_parser.get_int("save_mv");
|
||||
|
||||
// no dynamic quant case
|
||||
if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
// dynamic quant case, only in inference
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_mv)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int run_layernorm2d_fwd_combinations(std::string const& data_type)
|
||||
{
|
||||
constexpr size_t PARAM_COUNT = 20;
|
||||
char bufs[PARAM_COUNT][64];
|
||||
char* argv[PARAM_COUNT];
|
||||
|
||||
for(std::size_t i = 0; i < PARAM_COUNT; i++)
|
||||
{
|
||||
argv[i] = bufs[i];
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> fquant = {
|
||||
{}, {"-fquant=1", "-prec_o=int8"}, {"-fquant=1", "-prec_o=fp8"}};
|
||||
|
||||
std::vector<std::string> fadd = {"-fadd=0", "-fadd=1"};
|
||||
|
||||
std::vector<std::vector<std::string>> params = {
|
||||
{"-m=99", "-n=13"},
|
||||
{"-m=17", "-n=16"},
|
||||
{"-m=1", "-n=100"},
|
||||
{"-m=4", "-n=128"},
|
||||
{"-m=80", "-n=127"},
|
||||
{"-m=22", "-n=255 -stride=256"},
|
||||
{"-m=7", "-n=599"},
|
||||
{"-m=19", "-n=512"},
|
||||
{"-m=33", "-n=313 -stride=1000"},
|
||||
{"-m=11", "-n=510"},
|
||||
{"-m=171", "-n=676 -stride=818"},
|
||||
{"-m=91", "-n=636"},
|
||||
{"-m=12", "-n=768 -stride=800"},
|
||||
{"-m=100", "-n=766 -stride=812"},
|
||||
{"-m=31", "-n=1024"},
|
||||
{"-m=64", "-n=1000 -stride=1004"},
|
||||
{"-m=8", "-n=1501"},
|
||||
{"-m=3", "-n=1826"},
|
||||
{"-m=5", "-n=2040"},
|
||||
{"-m=7", "-n=2734"},
|
||||
{"-m=1", "-n=3182"},
|
||||
{"-m=9", "-n=4096"},
|
||||
{"-m=3", "-n=8192"},
|
||||
{"-m=3", "-n=9120"},
|
||||
{"-m=1", "-n=10547"},
|
||||
};
|
||||
|
||||
bool result = true;
|
||||
int argc = 0;
|
||||
std::vector<int> argc_stack;
|
||||
std::string pr_i = "-prec_i=" + data_type;
|
||||
strncpy(bufs[argc++], "layernorm2d_fwd", 64);
|
||||
strncpy(bufs[argc++], pr_i.c_str(), 64);
|
||||
argc_stack.push_back(argc);
|
||||
for(size_t fquant_idx = 0; fquant_idx < fquant.size(); fquant_idx++)
|
||||
{
|
||||
argc = argc_stack.back();
|
||||
for(size_t j = 0; j < fquant[fquant_idx].size(); j++)
|
||||
{
|
||||
strncpy(bufs[argc++], fquant[fquant_idx][j].c_str(), 64);
|
||||
}
|
||||
argc_stack.push_back(argc);
|
||||
for(size_t fadd_idx = 0; fadd_idx < fadd.size(); fadd_idx++)
|
||||
{
|
||||
argc = argc_stack.back();
|
||||
strncpy(bufs[argc++], fadd[fadd_idx].c_str(), 64);
|
||||
argc_stack.push_back(argc);
|
||||
for(size_t param_idx = 0; param_idx < params.size(); param_idx++)
|
||||
{
|
||||
argc = argc_stack.back();
|
||||
for(size_t j = 0; j < params[param_idx].size(); j++)
|
||||
{
|
||||
strncpy(bufs[argc++], params[param_idx][j].c_str(), 64);
|
||||
}
|
||||
|
||||
result = dispatch_by_type(argc, argv) && result;
|
||||
}
|
||||
argc_stack.pop_back();
|
||||
}
|
||||
argc_stack.pop_back();
|
||||
}
|
||||
argc_stack.pop_back();
|
||||
return result ? 0 : -1;
|
||||
}
|
||||
6
test/ck_tile/layernorm2d/layernorm2d_fwd_bf16.cpp
Normal file
6
test/ck_tile/layernorm2d/layernorm2d_fwd_bf16.cpp
Normal file
@@ -0,0 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd.inc"
|
||||
|
||||
int main() { return run_layernorm2d_fwd_combinations("bf16"); }
|
||||
6
test/ck_tile/layernorm2d/layernorm2d_fwd_fp16.cpp
Normal file
6
test/ck_tile/layernorm2d/layernorm2d_fwd_fp16.cpp
Normal file
@@ -0,0 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "layernorm2d_fwd.inc"
|
||||
|
||||
int main() { return run_layernorm2d_fwd_combinations("fp16"); }
|
||||
54
test/ck_tile/rmsnorm2d/CMakeLists.txt
Normal file
54
test/ck_tile/rmsnorm2d/CMakeLists.txt
Normal file
@@ -0,0 +1,54 @@
|
||||
function(create_tile_rmsnorm2d_fwd SUFFIX)
|
||||
set(TILE_RMSNORM2D_FWD "test_ck_tile_rmsnorm2d_fwd_${SUFFIX}")
|
||||
|
||||
message(DEBUG "adding ${TILE_RMSNORM2D_FWD}")
|
||||
add_test_executable(${TILE_RMSNORM2D_FWD} rmsnorm2d_fwd_${SUFFIX}.cpp)
|
||||
target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
|
||||
target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${RMSNORM2D_FWD_GEN_BLOBS})
|
||||
|
||||
set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS)
|
||||
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal --offload-compress)
|
||||
|
||||
target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS})
|
||||
endfunction()
|
||||
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
set(RMSNORM2D_FWD_KNOWN_APIS "fwd;bwd")
|
||||
set(RMSNORM2D_FWD_ENABLE_APIS "fwd" CACHE STRING
|
||||
"semicolon-separated list of APIs to generate (${RMSNORM2D_FWD_KNOWN_APIS}) & link, or \"all\".")
|
||||
if(RMSNORM2D_FWD_ENABLE_APIS STREQUAL "all")
|
||||
set(RMSNORM2D_FWD_ENABLE_APIS ${RMSNORM2D_FWD_KNOWN_APIS})
|
||||
endif()
|
||||
|
||||
# generate a list of kernels, but not actually emit files at config sta
|
||||
execute_process(
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --list_blobs
|
||||
RESULT_VARIABLE ret
|
||||
)
|
||||
if(ret AND NOT ret EQUAL 0)
|
||||
message( FATAL_ERROR "Fail to generate kernels via Python. ${ret}")
|
||||
endif()
|
||||
|
||||
file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/rmsnorm2d_fwd_blobs.txt RMSNORM2D_FWD_GEN_BLOBS)
|
||||
|
||||
add_custom_command(
|
||||
OUTPUT ${RMSNORM2D_FWD_GEN_BLOBS}
|
||||
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/generate.py
|
||||
--api ${RMSNORM2D_FWD_ENABLE_APIS} --working_path ${CMAKE_CURRENT_BINARY_DIR} --gen_blobs
|
||||
)
|
||||
|
||||
create_tile_rmsnorm2d_fwd("fp16")
|
||||
create_tile_rmsnorm2d_fwd("bf16")
|
||||
|
||||
# TODO: we have to turn off this global prop, otherwise the progress bar generated
|
||||
# by cmake will print too many files, execvp: /bin/sh: Argument list too long
|
||||
# however, this property may affect global
|
||||
# TODO: consider codegen a makefile by us
|
||||
set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)
|
||||
else()
|
||||
message(DEBUG "Skipping ck tile add_rmsnorm2d_rdquant_fwd tests for current target")
|
||||
endif()
|
||||
|
||||
715
test/ck_tile/rmsnorm2d/generate.py
Normal file
715
test/ck_tile/rmsnorm2d/generate.py
Normal file
@@ -0,0 +1,715 @@
|
||||
# SPDX-License-Identifier: MIT
|
||||
# Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
# generate kernel instances to speed up compilation
|
||||
|
||||
import argparse
|
||||
from enum import IntEnum
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import List, Optional, Any
|
||||
import functools
|
||||
import itertools
|
||||
import copy
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
def get_if_str(idx, total, lase_else = True):
|
||||
if idx == 0:
|
||||
return 'if'
|
||||
elif idx < total - 1:
|
||||
return 'else if'
|
||||
else:
|
||||
if lase_else:
|
||||
return 'else'
|
||||
else:
|
||||
return 'else if'
|
||||
|
||||
FUSED_ADD_ENUM_STR_MAP = [
|
||||
'no',
|
||||
'pras', # pre-norm
|
||||
'pra' ] # post-norm
|
||||
|
||||
FUSED_FUSED_SWEEP_STR_MAP = [
|
||||
'no',
|
||||
'sdquant', # smooth dynamic quant
|
||||
'dquant' ] # dynamic quant (without sm_scale)
|
||||
|
||||
DATA_TYPE_MAP = {'fp32' : 'float',
|
||||
'fp16' : 'ck_tile::fp16_t',
|
||||
'bf16' : 'ck_tile::bf16_t',
|
||||
'int8' : 'ck_tile::int8_t',
|
||||
'fp8' : 'ck_tile::fp8_t'}
|
||||
|
||||
def BOOL_MAP(b_) -> str:
|
||||
if b_:
|
||||
return 'true'
|
||||
else:
|
||||
return 'false'
|
||||
|
||||
|
||||
class rmsnorm_fwd_codegen:
|
||||
API_TRAITS_DEFINE = """
|
||||
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename UnquantYDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
ck_tile::index_t kFusedAdd_ = 0,
|
||||
ck_tile::index_t kFusedQuant_ = 0>
|
||||
struct rmsnorm2d_fwd_traits_
|
||||
{
|
||||
using XDataType = ck_tile::remove_cvref_t<XDataType_>;
|
||||
using YDataType = ck_tile::remove_cvref_t<YDataType_>;
|
||||
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
|
||||
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
|
||||
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
|
||||
|
||||
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
|
||||
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
|
||||
static constexpr ck_tile::index_t total_warps =
|
||||
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
|
||||
|
||||
// num of warps along m
|
||||
static constexpr ck_tile::index_t BlockWarps_M = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
|
||||
}
|
||||
else
|
||||
{
|
||||
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
|
||||
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
|
||||
}
|
||||
}();
|
||||
|
||||
// num of warps along n
|
||||
static constexpr ck_tile::index_t BlockWarps_N = []() {
|
||||
if constexpr(is_warp_per_row)
|
||||
{
|
||||
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
|
||||
return 1;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
|
||||
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
|
||||
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
|
||||
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
|
||||
|
||||
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
|
||||
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
|
||||
|
||||
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
|
||||
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
|
||||
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
|
||||
using Vector = ck_tile::sequence<1, Vector_N_>;
|
||||
|
||||
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
|
||||
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kSaveInvRms = kSaveInvRms_;
|
||||
static constexpr bool kSaveUnquant = kSaveUnquant_;
|
||||
static constexpr bool kTwoPass = kTwoPass_;
|
||||
static constexpr ck_tile::index_t kFusedAdd = kFusedAdd_;
|
||||
static constexpr ck_tile::index_t kFusedQuant = kFusedQuant_;
|
||||
};
|
||||
|
||||
template <typename XDataType_,
|
||||
typename YDataType_,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_,
|
||||
typename UnquantYDataType_,
|
||||
ck_tile::index_t Repeat_M_, // each thread repeat along M
|
||||
ck_tile::index_t Repeat_N_, // each thread repeat along N
|
||||
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
|
||||
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
|
||||
ck_tile::index_t Vector_N_, // vector size along N
|
||||
bool kPadN_,
|
||||
bool kSaveInvRms_,
|
||||
bool kSaveUnquant_,
|
||||
bool kTwoPass_,
|
||||
int kFusedAdd_,
|
||||
int kFusedQuant_>
|
||||
using traits_ = rmsnorm2d_fwd_traits_<XDataType_,
|
||||
YDataType_,
|
||||
SmoothScaleDataType_,
|
||||
YScaleDataType_,
|
||||
UnquantYDataType_,
|
||||
Repeat_M_,
|
||||
Repeat_N_,
|
||||
ThreadPerBlock_M_,
|
||||
ThreadPerBlock_N_,
|
||||
Vector_N_,
|
||||
kPadN_,
|
||||
kSaveInvRms_,
|
||||
kSaveUnquant_,
|
||||
kTwoPass_,
|
||||
kFusedAdd_,
|
||||
kFusedQuant_>;
|
||||
"""
|
||||
|
||||
API_COMMON_HEADER = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
#include <ck_tile/ops/epilogue.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#pragma once
|
||||
|
||||
using S = ck_tile::stream_config;
|
||||
using A = rmsnorm2d_fwd_args;
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const S& s, A a)
|
||||
{{
|
||||
using XDataType = typename Traits_::XDataType;
|
||||
using YDataType = typename Traits_::YDataType;
|
||||
using SmoothScaleDataType = typename Traits_::SmoothScaleDataType;
|
||||
using YScaleDataType = typename Traits_::YScaleDataType;
|
||||
using UnquantYDataType = typename Traits_::UnquantYDataType;
|
||||
using ComputeDataType = typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType;
|
||||
|
||||
using PipelineTraits =
|
||||
ck_tile::Rmsnorm2dFwdTraits<Traits_::kPadN,
|
||||
Traits_::kSaveInvRms,
|
||||
Traits_::kSaveUnquant,
|
||||
Traits_::kTwoPass,
|
||||
static_cast<ck_tile::Rmsnorm2dFusedAddEnum>(Traits_::kFusedAdd),
|
||||
static_cast<ck_tile::Rmsnorm2dFusedQuantEnum>(Traits_::kFusedQuant)>;
|
||||
|
||||
using PipelineProblem =
|
||||
ck_tile::Rmsnorm2dFwdPipelineProblem<typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::XDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::GammaDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::ComputeDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::InvRmsDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::UnquantYDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::SmoothScaleDataType,
|
||||
typename RmsnormTypeConfig<XDataType, YDataType, SmoothScaleDataType, YScaleDataType>::YScaleDataType,
|
||||
typename Traits_::Shape,
|
||||
PipelineTraits>;
|
||||
|
||||
using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass<PipelineProblem>;
|
||||
using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass<PipelineProblem>;
|
||||
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
|
||||
|
||||
using Default2DEpilogueProblem = ck_tile::Default2DEpilogueProblem<ComputeDataType, YDataType, false, Traits_::kPadN, false>;
|
||||
using Default2DEpilogue = ck_tile::Default2DEpilogue<Default2DEpilogueProblem>;
|
||||
|
||||
static constexpr bool UseSmoothInputScale = Traits_::kFusedQuant == 1;
|
||||
using DynamicQuantEpilogueProblem = ck_tile::DynamicQuantEpilogueProblem<ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, typename Traits_::Shape,
|
||||
ck_tile::DynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
|
||||
|
||||
using DynamicQuantEpilogue = ck_tile::DynamicQuantEpilogue<DynamicQuantEpilogueProblem>;
|
||||
|
||||
using Default2DAndDynamicQuantEpilogueProblem = ck_tile::Default2DAndDynamicQuantEpilogueProblem<
|
||||
ComputeDataType, SmoothScaleDataType, YScaleDataType, YDataType, UnquantYDataType, typename Traits_::Shape,
|
||||
ck_tile::Default2DAndDynamicQuantEpilogueTraits<false, Traits_::kPadN, UseSmoothInputScale, false, true/*max3*/>>;
|
||||
using Default2DAndDynamicQuantEpilogue = ck_tile::Default2DAndDynamicQuantEpilogue<Default2DAndDynamicQuantEpilogueProblem>;
|
||||
|
||||
using Epilogue = std::conditional_t<Traits_::kFusedQuant != 0,
|
||||
std::conditional_t<Traits_::kSaveUnquant,
|
||||
Default2DAndDynamicQuantEpilogue,
|
||||
DynamicQuantEpilogue>,
|
||||
Default2DEpilogue>;
|
||||
|
||||
using Kernel = ck_tile::Rmsnorm2dFwd<Pipeline, Epilogue>;
|
||||
|
||||
const dim3 grids = Kernel::GridSize(a);
|
||||
constexpr dim3 blocks = Kernel::BlockSize();
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
auto kargs = Kernel::MakeKargs(a);
|
||||
if(s.log_level_ > 0)
|
||||
std::cout << ", " << Kernel::GetName() << std::flush;
|
||||
|
||||
return ck_tile::launch_kernel(
|
||||
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{{}}, grids, blocks, 0, kargs));
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
API_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <ck_tile/core.hpp>
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
|
||||
{F_traits_define}
|
||||
|
||||
// Note: this internal API only declare, not define here, otherwise will block `make -j`
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
|
||||
rmsnorm2d_fwd_args a,
|
||||
const ck_tile::stream_config& s)
|
||||
{{
|
||||
float r = -1;
|
||||
{F_dispatch}
|
||||
return r;
|
||||
}}
|
||||
|
||||
"""
|
||||
|
||||
INSTANCE_BASE = """
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd_api_common.hpp"
|
||||
|
||||
// clang-format off
|
||||
// rm rn tm tn vn pd rms 2p
|
||||
{F_instance_def}
|
||||
// clang-format on
|
||||
|
||||
"""
|
||||
|
||||
API_PER_DTYPE = """
|
||||
{F_if}(t.prec_i == \"{F_i_type}\" && t.prec_o == \"{F_o_type}\"){{
|
||||
{F_per_n_case}
|
||||
}}
|
||||
"""
|
||||
API_PER_N_CASE = """
|
||||
{F_if} {F_N_COND} {{
|
||||
{F_inner_dispatch}
|
||||
}}
|
||||
"""
|
||||
API_INNER_CASE = """
|
||||
{F_if} {F_VEC_COND}
|
||||
r={F_instance_func}(s, a);
|
||||
"""
|
||||
|
||||
def __init__(self, working_path, kernel_filter):
|
||||
self.working_path = working_path
|
||||
self.kernel_filter = kernel_filter
|
||||
|
||||
class k_fuesd_add_enum(IntEnum):
|
||||
F_NO_ADD = 0
|
||||
F_PRE_ADD = 1
|
||||
F_PRE_ADD_STORE_RESIDUAL = 2
|
||||
|
||||
class k_fused_sweep_enum(IntEnum):
|
||||
F_NO_SWEEP = 0
|
||||
F_RENORM = 1
|
||||
F_DYNAMIC_QUANT = 2
|
||||
|
||||
@dataclass
|
||||
class k_traits:
|
||||
F_kPadN : bool
|
||||
F_kSaveMeanInvStd : bool
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : Any
|
||||
F_kFusedQuant : Any
|
||||
|
||||
@dataclass
|
||||
class k_shape:
|
||||
F_BlockTile : List[int]
|
||||
F_WarpPerBlock : List[int]
|
||||
F_WarpTile : List[int]
|
||||
F_Vector_ : List[int]
|
||||
@property
|
||||
def F_BlockSize(self) -> int:
|
||||
return functools.reduce(lambda a, b: a*b, self.F_WarpTile)
|
||||
|
||||
@dataclass
|
||||
class k_problem:
|
||||
F_XDataType : str
|
||||
F_GammaDataType : str
|
||||
F_ComputeDataType : str
|
||||
F_YDataType : str
|
||||
F_InvRmsDataType : str
|
||||
F_BlockShape : str
|
||||
F_Traits : Any #k_traits
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_one_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class k_pipeline_two_pass:
|
||||
F_Problem : Any #k_problem
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue_problem:
|
||||
F_AccDataType : str
|
||||
F_ODataType : str
|
||||
F_kPadM : bool
|
||||
F_kPadN : bool
|
||||
|
||||
@dataclass
|
||||
class default_2d_epilogue:
|
||||
F_problem : Any
|
||||
|
||||
@dataclass
|
||||
class k_kernel:
|
||||
F_pipeline : Any
|
||||
F_epilogue : Any
|
||||
|
||||
@dataclass
|
||||
class h_traits:
|
||||
F_XDataType : str
|
||||
F_YDataType : str
|
||||
F_SmoothScaleDataType : str
|
||||
F_YScaleDataType : str
|
||||
F_UnquantYDataType : str
|
||||
F_Repeat_M : int
|
||||
F_Repeat_N : int
|
||||
F_ThreadPerBlock_M : int
|
||||
F_ThreadPerBlock_N : int
|
||||
F_Vector_N : int
|
||||
F_kPadN : bool
|
||||
F_kSaveInvRms : bool
|
||||
F_kSaveUnquant: bool
|
||||
F_kTwoPass : bool
|
||||
F_kFusedAdd : int
|
||||
F_kFusedQuant : int
|
||||
|
||||
@property
|
||||
def trait_name(self) ->str:
|
||||
t_ = f'{DATA_TYPE_MAP[self.F_XDataType]}, {DATA_TYPE_MAP[self.F_YDataType]}, {DATA_TYPE_MAP[self.F_SmoothScaleDataType]}, {DATA_TYPE_MAP[self.F_YScaleDataType]}, {DATA_TYPE_MAP[self.F_UnquantYDataType]}, {self.F_Repeat_M:2}, {self.F_Repeat_N:2}, {self.F_ThreadPerBlock_M:2}, {self.F_ThreadPerBlock_N:4}'
|
||||
t_ += f', {self.F_Vector_N:2}, {BOOL_MAP(self.F_kPadN):5}, {BOOL_MAP(self.F_kSaveInvRms):5}, {BOOL_MAP(self.F_kSaveUnquant):5}'
|
||||
t_ += f', {BOOL_MAP(self.F_kTwoPass):5}, {self.F_kFusedAdd:4}, {self.F_kFusedQuant:4}'
|
||||
return t_
|
||||
|
||||
# string when calling this kernel
|
||||
@property
|
||||
def call_name(self) -> str:
|
||||
return f'rmsnorm2d_fwd_<traits_<{self.trait_name}>>'
|
||||
|
||||
# string when define this kernel
|
||||
@property
|
||||
def def_name(self) -> str:
|
||||
return f'template float rmsnorm2d_fwd_<traits_<{self.trait_name}>>(const S&, A);'
|
||||
|
||||
# this class hold kernel under same source file
|
||||
@dataclass
|
||||
class h_instance:
|
||||
F_DataTypePair : str
|
||||
F_N : str
|
||||
F_add : int
|
||||
F_sweep : int
|
||||
F_saveunquant : bool
|
||||
instance_list : List[Any] # List[h_traits]
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
prec_i, prec_o = self.F_DataTypePair.split(',')
|
||||
dtype_str = f'{prec_i}' if prec_i == prec_o else f'{prec_i}_{prec_o}'
|
||||
nnn = f'rmsnorm2d_fwd_{dtype_str}_n{self.F_N}'
|
||||
if self.F_add != 0:
|
||||
nnn = nnn + '_' + FUSED_ADD_ENUM_STR_MAP[self.F_add]
|
||||
if self.F_sweep != 0:
|
||||
nnn = nnn + '_' + FUSED_FUSED_SWEEP_STR_MAP[self.F_sweep]
|
||||
if self.F_saveunquant:
|
||||
nnn = nnn + '_saveunquant'
|
||||
return nnn
|
||||
|
||||
@property
|
||||
def instance_name(self) ->str:
|
||||
return self.name
|
||||
|
||||
@property
|
||||
def content(self) ->str:
|
||||
instance_defs = ''
|
||||
for ins in self.instance_list:
|
||||
instance_defs += ins.def_name + '\n'
|
||||
return rmsnorm_fwd_codegen.INSTANCE_BASE.format(F_instance_def=instance_defs)
|
||||
|
||||
@property
|
||||
def name_api(self) -> str:
|
||||
return 'rmsnorm2d_fwd_api'
|
||||
|
||||
@property
|
||||
def name_common_header(self) -> str:
|
||||
return 'rmsnorm2d_fwd_api_common'
|
||||
|
||||
@property
|
||||
def content_api(self) -> str:
|
||||
# 1 sort based on dtype
|
||||
t_dtype_dict = dict()
|
||||
blobs = self.get_blobs()
|
||||
for blob in blobs:
|
||||
if blob.F_DataTypePair not in t_dtype_dict:
|
||||
t_dtype_dict[blob.F_DataTypePair] = {}
|
||||
if blob.F_N not in t_dtype_dict[blob.F_DataTypePair]:
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N] = []
|
||||
t_dtype_dict[blob.F_DataTypePair][blob.F_N].append(blob)
|
||||
|
||||
d_str = ''
|
||||
for i_d, dtype_ in enumerate(t_dtype_dict):
|
||||
blob_per_t = t_dtype_dict[dtype_]
|
||||
n_str = ''
|
||||
for i_n, n_ in enumerate(blob_per_t):
|
||||
blob_per_n = blob_per_t[n_]
|
||||
inner_str = ""
|
||||
for i_b, b_ in enumerate(blob_per_n):
|
||||
# generate single kernel instance file
|
||||
#vec_str = ""
|
||||
for i_ins, ins in enumerate(b_.instance_list):
|
||||
idx_in_n = i_b * len(b_.instance_list) + i_ins
|
||||
len_in_n = len(blob_per_n) * len(b_.instance_list)
|
||||
# _if = 'if' if i_ins == 0 else 'else if'
|
||||
if ins.F_kFusedQuant == 0:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep}'.format(f_fused_sweep = ins.F_kFusedQuant)
|
||||
elif ins.F_kFusedQuant == 1:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sm == \"{f_sx_type}\" && t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sx_type=ins.F_SmoothScaleDataType, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
|
||||
elif ins.F_kFusedQuant == 2:
|
||||
_sweep_cond = 't.fused_quant == {f_fused_sweep} && (t.prec_sy == \"{f_sy_type}\" && t.save_unquant == {f_suq})'.format(
|
||||
f_fused_sweep = ins.F_kFusedQuant, f_sy_type=ins.F_YScaleDataType, f_suq=BOOL_MAP(ins.F_kSaveUnquant))
|
||||
_cond = '((a.n % {f_vec_n} == 0) && (t.fused_add == {f_fused_add}) && ({f_sweep_cond}))'.format(
|
||||
f_vec_n = ins.F_Vector_N, f_fused_add = ins.F_kFusedAdd,
|
||||
f_sweep_cond = _sweep_cond)
|
||||
inner_str += self.API_INNER_CASE.format(F_if = get_if_str(idx_in_n, len_in_n, False),
|
||||
F_VEC_COND = _cond, F_instance_func=ins.call_name)
|
||||
#inner_str = inner_str + vec_str
|
||||
n_cnd = f'(a.n <= {n_})' if (i_n < len(blob_per_t) - 1) else ''
|
||||
n_str += self.API_PER_N_CASE.format(F_if = get_if_str(i_n, len(blob_per_t)), F_N_COND=n_cnd, F_inner_dispatch=inner_str)
|
||||
prec_i, prec_o = dtype_.split(',')
|
||||
d_str += self.API_PER_DTYPE.format(F_if = get_if_str(i_d, len(t_dtype_dict), False), F_i_type=prec_i, F_o_type=prec_o, F_per_n_case=n_str)
|
||||
|
||||
api_base = self.API_BASE.format(F_traits_define=self.API_TRAITS_DEFINE, F_dispatch=d_str)
|
||||
return api_base
|
||||
|
||||
@property
|
||||
def content_common_header(self) -> str:
|
||||
return self.API_COMMON_HEADER.format(F_traits_define=self.API_TRAITS_DEFINE)
|
||||
|
||||
def get_blobs(self):
|
||||
h_traits = rmsnorm_fwd_codegen.h_traits
|
||||
h_instance = rmsnorm_fwd_codegen.h_instance
|
||||
|
||||
dynamic_quant_out_dtype = ['int8', 'fp8']
|
||||
# some predefined support range
|
||||
# (prec_i,prec_o) for simplicity this string will be used as key for dict
|
||||
scale_list = [('fp32,fp32')]
|
||||
dtype_list = [('fp16,fp16'), ('bf16,bf16'),
|
||||
('fp16,int8'), ('bf16,int8'),
|
||||
('fp16,fp8'), ('bf16,fp8')] # NOTE: only fused-dynamic-quant use int8 out
|
||||
#fused_add_list = [0, 1, 2]
|
||||
#fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
fused_add_list = [0, 1]
|
||||
fused_sweep_list = [0, 1, 2] # NOTE: only single pass can use fused (smooth) dynamic quant
|
||||
bool_list = [False, True]
|
||||
|
||||
# rm rn tm tn vn pd mv unquant 2p add sweep
|
||||
h_trait_dict = {'64' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 8, 8, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'128' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 16, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'256' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'512' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'640' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 5, 4, 128, 1, True, False, False, False, 0, 0)],
|
||||
'768' : [ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 4, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 4, 64, 1, True, False, False, False, 0, 0)],
|
||||
'1024' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 2, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 2, 64, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 2, 64, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'1536' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 4, 64, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 2, 128, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'2048' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1, 256, 1, True, False, False, False, 0, 0)],
|
||||
'3072' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 128, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1, 256, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'4096' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 2, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'6144' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 3, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 6, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'8192' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 8, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 512, 4, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 2, True, False, False, False, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 8, 1,1024, 1, True, False, False, False, 0, 0)],
|
||||
'big' :[ h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 1, 1,1024, 8, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0),
|
||||
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0)]}
|
||||
total_blob = list()
|
||||
for hs_key in h_trait_dict:
|
||||
hs = h_trait_dict[hs_key]
|
||||
current_n = hs[0].F_Repeat_N * hs[0].F_ThreadPerBlock_N * hs[0].F_Vector_N
|
||||
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
|
||||
prec_i, prec_o = dtype.split(',')
|
||||
scale_sm, scale_y = scale_type.split(',')
|
||||
if prec_o in dynamic_quant_out_dtype and fused_quant != 1 and fused_quant != 2:
|
||||
continue # skip non dynamic quant case
|
||||
if (fused_quant == 1 or fused_quant == 2) and hs_key == 'big':
|
||||
continue
|
||||
if (fused_quant == 0 and save_unquant == True):
|
||||
continue # save_unquant should always be false when there is no quant enabled
|
||||
current_hs = list()
|
||||
for chs_ in hs:
|
||||
h_ = copy.copy(chs_) # copy the base instance out
|
||||
h_.F_XDataType = prec_i
|
||||
h_.F_YDataType = prec_o
|
||||
h_.F_SmoothScaleDataType = scale_sm
|
||||
h_.F_YScaleDataType = scale_y
|
||||
h_.F_UnquantYDataType = prec_i
|
||||
h_.F_kFusedAdd = fused_add
|
||||
h_.F_kFusedQuant = fused_quant
|
||||
h_.F_kSaveUnquant = save_unquant
|
||||
current_hs.append(h_) # + "\n"
|
||||
#f.write(str(f.parent / GEN_DIR / (blobs.api_common_header_
|
||||
current_n_str = 'big' if hs_key == 'big' else current_n
|
||||
total_blob.append(h_instance(dtype, current_n_str, fused_add, fused_quant, save_unquant, current_hs))
|
||||
return total_blob
|
||||
|
||||
def list_blobs(self) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
list_p = w_p / 'rmsnorm2d_fwd_blobs.txt'
|
||||
blobs = self.get_blobs()
|
||||
with list_p.open('w') as list_f:
|
||||
# api related file
|
||||
list_f.write(str(w_p / (self.name_api + ".cpp")) + "\n")
|
||||
list_f.write(str(w_p / (self.name_common_header + ".hpp")) + "\n")
|
||||
# kernel instance file
|
||||
for b in blobs:
|
||||
list_f.write(str(w_p / (b.name + ".cpp")) + "\n")
|
||||
|
||||
def gen_blobs(self) -> None:
|
||||
w_p = Path(self.working_path)
|
||||
(w_p / (self.name_api + ".cpp")).write_text(self.content_api)
|
||||
(w_p / (self.name_common_header + ".hpp")).write_text(self.content_common_header)
|
||||
blobs = self.get_blobs()
|
||||
for b in blobs:
|
||||
(w_p / (b.name + ".cpp")).write_text(b.content)
|
||||
|
||||
|
||||
def list_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
rmsnorm_fwd_codegen(args.working_path, args.filter).list_blobs()
|
||||
|
||||
|
||||
def gen_blobs(args):
|
||||
api_list = args.api.split(',')
|
||||
for api in api_list:
|
||||
if api == 'fwd':
|
||||
rmsnorm_fwd_codegen(args.working_path, args.filter).gen_blobs()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="generate",
|
||||
description="gen API for CK rmsnorm kernel",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--api",
|
||||
default='fwd[all]',
|
||||
required=False,
|
||||
help="supply API(s) to generate (default: fwd). separated by comma."
|
||||
)
|
||||
|
||||
# the directory for list_blobs/gen_blobs to write files into
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--working_path",
|
||||
default="./",
|
||||
required=False,
|
||||
help="the path where all the blobs are going to be generated"
|
||||
)
|
||||
|
||||
# this script have 2 modes
|
||||
# 1) list_blobs mode, will generate a txt file with all the files going to be generated.
|
||||
# this is useful in build system like cmake to construct source code dependency, by
|
||||
# reading the content out of this file
|
||||
# 2) gen_blobs mode, will generate the actuall kernel instance and api. If in framework
|
||||
# like FA, only need to use this mode
|
||||
parser.add_argument(
|
||||
"-l",
|
||||
"--list_blobs",
|
||||
action='store_true',
|
||||
help="list all the kernels to a file, "
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-g",
|
||||
"--gen_blobs",
|
||||
action='store_true',
|
||||
help="generate all kernels into different tile"
|
||||
)
|
||||
|
||||
# TODO: if using filter, must apply same value to output_dir and list_blobs
|
||||
parser.add_argument(
|
||||
"-f",
|
||||
"--filter",
|
||||
required=False,
|
||||
help="filter out kernels that need to generate, using fnmatch module"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--traits",
|
||||
default="all",
|
||||
required=False,
|
||||
help="enable/disable some feature. default generate all"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--receipt",
|
||||
default=0,
|
||||
required=False,
|
||||
help="codegen receipt."
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# print(f'{args.list_blobs}-{args.gen_blobs}')
|
||||
if (args.gen_blobs and args.list_blobs) or ((not args.gen_blobs) and (not args.list_blobs)):
|
||||
print('gen_blobs/list_blobs must specify only one option')
|
||||
sys.exit()
|
||||
|
||||
p = Path(args.working_path)
|
||||
if not p.exists():
|
||||
p.mkdir()
|
||||
|
||||
if args.list_blobs:
|
||||
list_blobs(args)
|
||||
else:
|
||||
gen_blobs(args)
|
||||
69
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp
Normal file
69
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/rmsnorm2d.hpp"
|
||||
#include <string>
|
||||
|
||||
template <typename InType,
|
||||
typename OutType,
|
||||
typename SmoothScaleDataType_,
|
||||
typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig;
|
||||
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig<ck_tile::half_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::half_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::half_t;
|
||||
using InvRmsDataType = ck_tile::half_t;
|
||||
using UnquantYDataType = ck_tile::half_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
template <typename OutType, typename SmoothScaleDataType_, typename YScaleDataType_>
|
||||
struct RmsnormTypeConfig<ck_tile::bf16_t, OutType, SmoothScaleDataType_, YScaleDataType_>
|
||||
{
|
||||
using XDataType = ck_tile::bf16_t;
|
||||
using YDataType = OutType;
|
||||
using GammaDataType = ck_tile::bf16_t;
|
||||
using InvRmsDataType = ck_tile::bf16_t;
|
||||
using UnquantYDataType = ck_tile::bf16_t;
|
||||
using ComputeDataType = float;
|
||||
using SmoothScaleDataType = SmoothScaleDataType_;
|
||||
using YScaleDataType = YScaleDataType_;
|
||||
};
|
||||
|
||||
// runtime args
|
||||
struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Traits_>
|
||||
float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a);
|
||||
|
||||
// This is the public API, will be generated by script
|
||||
struct rmsnorm2d_fwd_traits
|
||||
{
|
||||
std::string prec_i; // input precision
|
||||
std::string prec_o; // output precision
|
||||
|
||||
// if fused_quant == 1, need set prec_sm/prec_sy to proper string, otherwise can set
|
||||
// arbitrary(will skip check) if fused_quant == 2, need set prec_sy to proper string, otherwise
|
||||
// can set arbitrary(will skip check)
|
||||
std::string prec_sm; // x-scale, used for [1*N] input smooth quant
|
||||
std::string prec_sy; // y-scale, used for [M*1] output for next layer
|
||||
|
||||
bool save_rms;
|
||||
bool save_unquant;
|
||||
int fused_add; // 0:no-add, 1:pre-add-store, 2:pre-add
|
||||
int fused_quant; // 0:no-sweep, 1:smooth-dynamic-quant, 2:dynamic-quant
|
||||
};
|
||||
|
||||
float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&);
|
||||
619
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc
Normal file
619
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd.inc
Normal file
@@ -0,0 +1,619 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "rmsnorm2d_fwd.hpp"
|
||||
#include <cstring>
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>()
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::int8_t>()
|
||||
{
|
||||
double rtol = 1e-02;
|
||||
double atol = 1.0;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "3328", "m dimension")
|
||||
.insert("n", "4096", "n dimension")
|
||||
.insert("x_stride", "-1", "x row_stride, if -1 then equal to n")
|
||||
.insert("xr_stride", "-1", "x residule row_stride, if -1 then equal to n")
|
||||
.insert("y_stride", "-1", "y row_stride, if -1 then equal to n")
|
||||
.insert("yr_stride", "-1", "y residule row_stride, if -1 then equal to n")
|
||||
.insert("e", "1e-5", "epsilon")
|
||||
.insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case")
|
||||
.insert("save_unquant", "0", "save result before quant")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("kname", "1", "print kernel name or not")
|
||||
.insert("prec_i", "fp16", "input precision")
|
||||
.insert("prec_o", "auto", "output precision, set auto will be the same as input")
|
||||
.insert("prec_sm",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1")
|
||||
.insert("prec_sy",
|
||||
"auto",
|
||||
"output quant scale type, set auto will use fp32. used when fquant=1 or 2")
|
||||
.insert("fadd", "0", "fused-add, 0:no fused add, 1:preadd+store, 2:preadd only")
|
||||
.insert("fquant", "0", "fused-quant, 0:no, 1:smooth-dynamic-quant, 2:dynamic-quant")
|
||||
.insert("warmup", "5", "cold iter")
|
||||
.insert("repeat", "20", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InDataType,
|
||||
typename OutDataType,
|
||||
typename SmoothScaleDataType,
|
||||
typename YScaleDataType,
|
||||
bool SaveRms,
|
||||
bool SaveUnquant>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t m = arg_parser.get_int("m");
|
||||
ck_tile::index_t n = arg_parser.get_int("n");
|
||||
float epsilon = arg_parser.get_float("e");
|
||||
int kname = arg_parser.get_int("kname");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int fused_add = arg_parser.get_int("fadd");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
|
||||
if(x_stride < 0)
|
||||
x_stride = n;
|
||||
ck_tile::index_t xr_stride = arg_parser.get_int("xr_stride");
|
||||
if(xr_stride < 0)
|
||||
xr_stride = n;
|
||||
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
|
||||
if(y_stride < 0)
|
||||
y_stride = n;
|
||||
ck_tile::index_t yr_stride = arg_parser.get_int("yr_stride");
|
||||
if(yr_stride < 0)
|
||||
yr_stride = n;
|
||||
assert(x_stride >= n);
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
|
||||
if((fused_quant == 1 || fused_quant == 2) && prec_o != "int8" && prec_o != "fp8")
|
||||
{
|
||||
std::cout
|
||||
<< "if fused_quant is 1 or 2, only support \"-prec_o=int8\" or \"-prec_o=fp8\" cases."
|
||||
<< std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
if((fused_quant == 0) && SaveUnquant)
|
||||
{
|
||||
std::cout
|
||||
<< "save_unquant should be 0 if quant output is not enabled because it is meaningless. "
|
||||
<< "Output Y is what wanted." << std::endl;
|
||||
return false;
|
||||
}
|
||||
|
||||
using TypeConfig =
|
||||
RmsnormTypeConfig<InDataType, OutDataType, SmoothScaleDataType, YScaleDataType>;
|
||||
|
||||
using XDataType = typename TypeConfig::XDataType;
|
||||
using YDataType = typename TypeConfig::YDataType;
|
||||
using GammaDataType = typename TypeConfig::GammaDataType;
|
||||
using XResidualDataType = XDataType;
|
||||
using YResidualDataType = XDataType;
|
||||
|
||||
using InvRmsDataType =
|
||||
std::conditional_t<SaveRms, typename TypeConfig::InvRmsDataType, ck_tile::null_type>;
|
||||
using UnquantYDataType =
|
||||
std::conditional_t<SaveUnquant, typename TypeConfig::UnquantYDataType, ck_tile::null_type>;
|
||||
|
||||
using ComputeDataType = typename TypeConfig::ComputeDataType;
|
||||
|
||||
// host verify
|
||||
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
|
||||
ck_tile::HostTensor<GammaDataType> gamma_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host({n});
|
||||
ck_tile::HostTensor<SmoothScaleDataType> sm_scale_host_dev({n});
|
||||
|
||||
ck_tile::HostTensor<XResidualDataType> x_residual_host({m, n}, {xr_stride, 1});
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host({m, n}, {yr_stride, 1});
|
||||
|
||||
ck_tile::HostTensor<YDataType> y_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host_dev({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_ref({m});
|
||||
ck_tile::HostTensor<YScaleDataType> y_scale_host_dev({m});
|
||||
|
||||
ck_tile::HostTensor<InvRmsDataType> invRms_host_ref({m});
|
||||
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_ref({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<UnquantYDataType> unquant_y_host_dev({m, n}, {y_stride, 1});
|
||||
ck_tile::HostTensor<ck_tile::null_type> unquant_y_null({1});
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
|
||||
ck_tile::FillUniformDistribution<XResidualDataType>{-.5f, .5f}(x_residual_host);
|
||||
ck_tile::FillUniformDistribution<SmoothScaleDataType>{-1.f, 1.f}(sm_scale_host);
|
||||
ck_tile::FillUniformDistribution<GammaDataType>{-.5f, .5f}(gamma_host);
|
||||
|
||||
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_scale_buf(y_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem sm_scale_buf(sm_scale_host_dev.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_residual_buf(x_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_residual_buf(y_residual_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem unquant_y_buf(unquant_y_host_dev.get_element_space_size_in_bytes());
|
||||
|
||||
x_buf.ToDevice(x_host.data());
|
||||
gamma_buf.ToDevice(gamma_host.data());
|
||||
x_residual_buf.ToDevice(x_residual_host.data());
|
||||
sm_scale_buf.ToDevice(sm_scale_host.data());
|
||||
|
||||
auto prec_str = [&]() {
|
||||
auto base_str = prec_i;
|
||||
if(prec_i != prec_o)
|
||||
{
|
||||
base_str += "|" + prec_o;
|
||||
}
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
base_str += std::string("(") + prec_sy + ")";
|
||||
}
|
||||
return base_str;
|
||||
}();
|
||||
|
||||
std::cout << "[" << prec_str << "]"
|
||||
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride
|
||||
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
|
||||
<< ", yr_stride:" << yr_stride << std::flush;
|
||||
|
||||
rmsnorm2d_fwd_traits traits{
|
||||
prec_i, prec_o, prec_sm, prec_sy, SaveRms, SaveUnquant, fused_add, fused_quant};
|
||||
|
||||
rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(),
|
||||
fused_add != 0 ? x_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant == 1 ? sm_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
gamma_buf.GetDeviceBuffer(),
|
||||
y_buf.GetDeviceBuffer(),
|
||||
fused_add == 1 ? y_residual_buf.GetDeviceBuffer() : nullptr,
|
||||
fused_quant != 0 ? y_scale_buf.GetDeviceBuffer() : nullptr,
|
||||
nullptr, // p_invRms, unsupported yet
|
||||
SaveUnquant ? unquant_y_buf.GetDeviceBuffer() : nullptr,
|
||||
epsilon,
|
||||
m,
|
||||
n,
|
||||
x_stride, // x row_stride
|
||||
xr_stride, // x residule row stride
|
||||
y_stride, // y row stride
|
||||
yr_stride}; // y residule row stride
|
||||
|
||||
float ave_time = rmsnorm2d_fwd(
|
||||
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
|
||||
|
||||
std::size_t num_byte =
|
||||
sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n;
|
||||
num_byte += SaveRms ? sizeof(InvRmsDataType) * m * n : 0;
|
||||
num_byte += SaveUnquant ? sizeof(UnquantYDataType) * m * n : 0;
|
||||
num_byte += fused_add ? sizeof(XResidualDataType) * m * n : 0;
|
||||
num_byte += ((fused_quant == 1) || (fused_quant == 2)) ? sizeof(YScaleDataType) * m : 0;
|
||||
num_byte += (fused_quant == 1) ? sizeof(SmoothScaleDataType) * n : 0;
|
||||
|
||||
float gb_per_sec = num_byte / 1.E6 / ave_time;
|
||||
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
|
||||
|
||||
bool pass = true;
|
||||
|
||||
if(do_validation)
|
||||
{
|
||||
// reference
|
||||
if(fused_add != 0)
|
||||
{
|
||||
// fused pre_add/pre_add_store
|
||||
// TODO we accumulate directly to x_host for simplcity here...
|
||||
std::transform(x_host.mData.cbegin(),
|
||||
x_host.mData.cend(),
|
||||
x_residual_host.mData.cbegin(),
|
||||
x_host.mData.begin(),
|
||||
[](auto x_, auto r_) {
|
||||
auto o_ = ck_tile::type_convert<ComputeDataType>(x_) +
|
||||
ck_tile::type_convert<ComputeDataType>(r_);
|
||||
return ck_tile::type_convert<XDataType>(o_);
|
||||
});
|
||||
}
|
||||
|
||||
if(fused_quant != 0)
|
||||
{
|
||||
auto dquant_functor = [&](int m_, auto& o_, auto& acc_) {
|
||||
int N_ = acc_.mDesc.get_lengths()[1];
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
// input smooth outlier
|
||||
acc_(m_, n_) = acc_(m_, n_) *
|
||||
ck_tile::type_convert<ComputeDataType>(sm_scale_host(n_));
|
||||
}
|
||||
}
|
||||
ComputeDataType absmax = static_cast<ComputeDataType>(0);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
const auto a = ck_tile::abs(acc_(m_, n_));
|
||||
absmax = a > absmax ? a : absmax;
|
||||
}
|
||||
// printf("cpu:absmax:%f\n", absmax);
|
||||
constexpr ComputeDataType kMaxY =
|
||||
std::is_same<YDataType, ck_tile::fp8_t>::value ? 240.0
|
||||
: std::is_same<YDataType, ck_tile::int8_t>::value ? 127.0
|
||||
: 0.0;
|
||||
ComputeDataType y_scale = absmax / kMaxY;
|
||||
y_scale_host_ref(m_) = ck_tile::type_convert<YScaleDataType>(y_scale);
|
||||
for(int n_ = 0; n_ < N_; n_++)
|
||||
{
|
||||
o_(m_, n_) = ck_tile::type_convert<YDataType>(acc_(m_, n_) / y_scale);
|
||||
}
|
||||
};
|
||||
|
||||
auto default_and_dquant_functor = [&](int m_, auto& o_unquant_, auto& o_, auto& acc_) {
|
||||
const int N = acc_.mDesc.get_lengths()[1];
|
||||
for(int n_ = 0; n_ < N; ++n_)
|
||||
{
|
||||
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
|
||||
}
|
||||
|
||||
dquant_functor(m_, o_, acc_);
|
||||
};
|
||||
|
||||
if constexpr(SaveUnquant)
|
||||
{
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(x_host,
|
||||
gamma_host,
|
||||
y_host_ref,
|
||||
invRms_host_ref,
|
||||
unquant_y_host_ref,
|
||||
epsilon,
|
||||
default_and_dquant_functor);
|
||||
}
|
||||
else
|
||||
{
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
UnquantYDataType>(x_host,
|
||||
gamma_host,
|
||||
y_host_ref,
|
||||
invRms_host_ref,
|
||||
unquant_y_host_ref,
|
||||
epsilon,
|
||||
dquant_functor);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
assert(SaveUnquant == false);
|
||||
ck_tile::reference_rmsnorm2d_fwd<XDataType,
|
||||
GammaDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
InvRmsDataType,
|
||||
ck_tile::null_type>(
|
||||
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
|
||||
}
|
||||
|
||||
y_buf.FromDevice(y_host_dev.data());
|
||||
|
||||
ck_tile::HostTensor<YResidualDataType> y_residual_host_dev({m, n}, {yr_stride, 1});
|
||||
if(fused_add == 1)
|
||||
{
|
||||
y_residual_buf.FromDevice(y_residual_host_dev.data());
|
||||
}
|
||||
|
||||
auto [rtol, atol] = get_elimit<YDataType>();
|
||||
if(x_stride == n)
|
||||
{
|
||||
pass = ck_tile::check_err(
|
||||
y_host_dev, y_host_ref, std::string("\nOUT Error: Incorrect results!"), rtol, atol);
|
||||
|
||||
if constexpr(SaveUnquant)
|
||||
{
|
||||
pass &= ck_tile::check_err(unquant_y_host_dev,
|
||||
unquant_y_host_ref,
|
||||
std::string("\n OUT ERROR: Incorrect unquant results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
if(fused_add == 1)
|
||||
{
|
||||
pass &= ck_tile::check_err(y_residual_host_dev,
|
||||
x_host,
|
||||
std::string("\nADD Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(int i_r = 0; i_r < m; i_r++)
|
||||
{
|
||||
std::vector<YDataType> y_host_dev_row(y_host_dev.begin() + i_r * y_stride,
|
||||
y_host_dev.begin() + i_r * y_stride + n);
|
||||
std::vector<YDataType> y_host_ref_row(y_host_ref.begin() + i_r * y_stride,
|
||||
y_host_ref.begin() + i_r * y_stride + n);
|
||||
pass &= ck_tile::check_err(y_host_dev_row,
|
||||
y_host_ref_row,
|
||||
std::string("\nOUT[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
|
||||
if(fused_add == 1)
|
||||
{
|
||||
std::vector<YResidualDataType> y_residual_host_dev_row(
|
||||
y_residual_host_dev.begin() + i_r * yr_stride,
|
||||
y_residual_host_dev.begin() + i_r * yr_stride + n);
|
||||
std::vector<YResidualDataType> y_residual_host_ref_row(
|
||||
x_host.begin() + i_r * yr_stride, x_host.begin() + i_r * yr_stride + n);
|
||||
pass &= ck_tile::check_err(y_residual_host_dev_row,
|
||||
y_residual_host_ref_row,
|
||||
std::string("\nADD[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
if constexpr(SaveUnquant)
|
||||
{
|
||||
std::vector<UnquantYDataType> unquant_y_host_dev_row(
|
||||
unquant_y_host_dev.begin() + i_r * y_stride,
|
||||
unquant_y_host_dev.begin() + i_r * y_stride + n);
|
||||
std::vector<UnquantYDataType> unquant_y_host_ref_row(
|
||||
unquant_y_host_ref.begin() + i_r * y_stride,
|
||||
unquant_y_host_ref.begin() + i_r * y_stride + n);
|
||||
pass &=
|
||||
ck_tile::check_err(unquant_y_host_dev_row,
|
||||
unquant_y_host_ref_row,
|
||||
std::string("\nOUT[") + std::to_string(i_r) +
|
||||
std::string("] Error: Incorrect unquant y results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if(fused_quant == 1)
|
||||
{
|
||||
y_scale_buf.FromDevice(y_scale_host_dev.data());
|
||||
pass &= ck_tile::check_err(y_scale_host_dev,
|
||||
y_scale_host_ref,
|
||||
std::string("\nSCALE Error: Incorrect results!"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
|
||||
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
bool is_quant_data_type(const std::string& prec) { return (prec == "int8") || (prec == "fp8"); }
|
||||
|
||||
bool dispatch_by_type(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return false;
|
||||
|
||||
std::string prec_i = arg_parser.get_str("prec_i");
|
||||
std::string prec_o = arg_parser.get_str("prec_o");
|
||||
std::string prec_sm = arg_parser.get_str("prec_sm");
|
||||
std::string prec_sy = arg_parser.get_str("prec_sy");
|
||||
if(prec_o == "auto")
|
||||
{
|
||||
prec_o = prec_i;
|
||||
}
|
||||
if(prec_sm == "auto")
|
||||
{
|
||||
prec_sm = "fp32";
|
||||
}
|
||||
if(prec_sy == "auto")
|
||||
{
|
||||
prec_sy = "fp32";
|
||||
}
|
||||
|
||||
int save_rms = arg_parser.get_int("save_rms");
|
||||
int fused_quant = arg_parser.get_int("fquant");
|
||||
int save_unquant =
|
||||
arg_parser.get_int("save_unquant") && is_quant_data_type(prec_o) && (fused_quant != 0);
|
||||
|
||||
if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" && save_rms)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, true, false>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::half_t, float, float, false, false>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, true, false>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "bf16" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::bf16_t, float, float, false, false>(arg_parser);
|
||||
}
|
||||
|
||||
// dynamic quant case, only in inference
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, true, false>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true, false>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false, false>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && !save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false, false>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::int8_t, float, float, true, true>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "int8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::int8_t, float, float, true, true>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "fp16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::half_t, ck_tile::fp8_t, float, float, false, true>(arg_parser);
|
||||
}
|
||||
else if(prec_i == "bf16" && prec_o == "fp8" && prec_sm == "fp32" && prec_sy == "fp32" &&
|
||||
!save_rms && save_unquant)
|
||||
{
|
||||
return run<ck_tile::bf16_t, ck_tile::fp8_t, float, float, false, true>(arg_parser);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
int run_rmsnorm2d_fwd_combinations(std::string const& data_type)
|
||||
{
|
||||
constexpr size_t PARAM_COUNT = 20;
|
||||
char bufs[PARAM_COUNT][64];
|
||||
char* argv[PARAM_COUNT];
|
||||
|
||||
for(std::size_t i = 0; i < PARAM_COUNT; i++)
|
||||
{
|
||||
argv[i] = bufs[i];
|
||||
}
|
||||
|
||||
std::vector<std::vector<std::string>> fquant = {
|
||||
{},
|
||||
{"-fquant=1", "-prec_o=int8"},
|
||||
{"-fquant=2", "-prec_o=int8"},
|
||||
{"-fquant=1", "-prec_o=fp8"},
|
||||
{"-fquant=2", "-prec_o=fp8"},
|
||||
{"-fquant=1", "-prec_o=int8", "-save_unquant=1"},
|
||||
{"-fquant=2", "-prec_o=int8", "-save_unquant=1"},
|
||||
{"-fquant=1", "-prec_o=fp8", "-save_unquant=1"},
|
||||
{"-fquant=2", "-prec_o=fp8", "-save_unquant=1"}};
|
||||
|
||||
std::vector<std::string> fadd = {"-fadd=0", "-fadd=1"};
|
||||
|
||||
std::vector<std::vector<std::string>> params = {
|
||||
{"-m=99", "-n=13"},
|
||||
{"-m=17", "-n=16"},
|
||||
{"-m=1", "-n=100"},
|
||||
{"-m=4", "-n=128"},
|
||||
{"-m=80", "-n=127"},
|
||||
{"-m=22", "-n=255", "-stride=256"},
|
||||
{"-m=7", "-n=599"},
|
||||
{"-m=19", "-n=512"},
|
||||
{"-m=33", "-n=313", "-stride=1000"},
|
||||
{"-m=11", "-n=510"},
|
||||
{"-m=171", "-n=676", "-stride=818"},
|
||||
{"-m=91", "-n=636"},
|
||||
{"-m=12", "-n=768", "-stride=800"},
|
||||
{"-m=100", "-n=766", "-stride=812"},
|
||||
{"-m=31", "-n=1024"},
|
||||
{"-m=64", "-n=1000", "-stride=1004"},
|
||||
{"-m=8", "-n=1501"},
|
||||
{"-m=3", "-n=1826"},
|
||||
{"-m=5", "-n=2040"},
|
||||
{"-m=7", "-n=2734"},
|
||||
{"-m=1", "-n=3182"},
|
||||
{"-m=9", "-n=4096"},
|
||||
{"-m=3", "-n=8192"},
|
||||
};
|
||||
|
||||
bool result = true;
|
||||
int argc = 0;
|
||||
std::vector<int> argc_stack;
|
||||
std::string pr_i = "-prec_i=" + data_type;
|
||||
strncpy(bufs[argc++], "rmsnorm2d_fwd", 64);
|
||||
strncpy(bufs[argc++], pr_i.c_str(), 64);
|
||||
argc_stack.push_back(argc);
|
||||
for(size_t fquant_idx = 0; fquant_idx < fquant.size(); fquant_idx++)
|
||||
{
|
||||
argc = argc_stack.back();
|
||||
for(size_t j = 0; j < fquant[fquant_idx].size(); j++)
|
||||
{
|
||||
strncpy(bufs[argc++], fquant[fquant_idx][j].c_str(), 64);
|
||||
}
|
||||
argc_stack.push_back(argc);
|
||||
for(size_t fadd_idx = 0; fadd_idx < fadd.size(); fadd_idx++)
|
||||
{
|
||||
argc = argc_stack.back();
|
||||
strncpy(bufs[argc++], fadd[fadd_idx].c_str(), 64);
|
||||
argc_stack.push_back(argc);
|
||||
for(size_t param_idx = 0; param_idx < params.size(); param_idx++)
|
||||
{
|
||||
argc = argc_stack.back();
|
||||
for(size_t j = 0; j < params[param_idx].size(); j++)
|
||||
{
|
||||
strncpy(bufs[argc++], params[param_idx][j].c_str(), 64);
|
||||
}
|
||||
|
||||
result = dispatch_by_type(argc, argv) && result;
|
||||
}
|
||||
argc_stack.pop_back();
|
||||
}
|
||||
argc_stack.pop_back();
|
||||
}
|
||||
argc_stack.pop_back();
|
||||
return result ? 0 : -1;
|
||||
}
|
||||
5
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_bf16.cpp
Normal file
5
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_bf16.cpp
Normal file
@@ -0,0 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd.inc"
|
||||
int main() { return run_rmsnorm2d_fwd_combinations("bf16"); }
|
||||
5
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_fp16.cpp
Normal file
5
test/ck_tile/rmsnorm2d/rmsnorm2d_fwd_fp16.cpp
Normal file
@@ -0,0 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "rmsnorm2d_fwd.inc"
|
||||
int main() { return run_rmsnorm2d_fwd_combinations("fp16"); }
|
||||
19
test/ck_tile/topk_softmax/CMakeLists.txt
Normal file
19
test/ck_tile/topk_softmax/CMakeLists.txt
Normal file
@@ -0,0 +1,19 @@
|
||||
function(add_tile_topk_softmax_test SUFFIX)
|
||||
set(TEST_NAME "test_ck_tile_topk_softmax_${SUFFIX}")
|
||||
add_test_executable(${TEST_NAME} test_topk_softmax_${SUFFIX}.cpp test_topk_softmax_api.cpp)
|
||||
target_include_directories(${TEST_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/)
|
||||
|
||||
set(TEST_TOPK_SOFTMAX_COMPILE_OPTIONS)
|
||||
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
|
||||
list(APPEND TEST_TOPK_SOFTMAX_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
|
||||
# list(APPEND TEST_TOPK_SOFTMAX_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
|
||||
target_compile_options(${TEST_NAME} PRIVATE ${TEST_TOPK_SOFTMAX_COMPILE_OPTIONS})
|
||||
endfunction()
|
||||
|
||||
# Currently ck_tile is only built on gfx9
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_tile_topk_softmax_test(fp16)
|
||||
add_tile_topk_softmax_test(bf16)
|
||||
else()
|
||||
message(DEBUG "Skipping tile topk_softmax tests for current target")
|
||||
endif()
|
||||
280
test/ck_tile/topk_softmax/test_topk_softmax.hpp
Normal file
280
test/ck_tile/topk_softmax/test_topk_softmax.hpp
Normal file
@@ -0,0 +1,280 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cassert>
|
||||
#include <cstdlib>
|
||||
#include <time.h>
|
||||
|
||||
#include "test_topk_softmax_api.hpp"
|
||||
|
||||
// CPU reference
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
|
||||
|
||||
auto [y_values, y_indices] = reference_topk(y, k, dim, largest, sorted);
|
||||
|
||||
return ck_tile::make_tuple(y_values, y_indices);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
auto reference_topk_softmax(const ck_tile::HostTensor<InputType>& x,
|
||||
ck_tile::HostTensor<WeightType>& y_values,
|
||||
ck_tile::HostTensor<IndexType>& y_indices,
|
||||
ck_tile::index_t k,
|
||||
ck_tile::index_t dim = -1,
|
||||
bool largest = true,
|
||||
bool sorted = true)
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
auto y = reference_softmax<InputType, WeightType, WeightType>(x, dim);
|
||||
reference_topk(y, y_values, y_indices, k, dim, largest, sorted);
|
||||
}
|
||||
|
||||
// different threshold for different dtype
|
||||
template <typename DataType>
|
||||
auto get_elimit(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-3;
|
||||
double atol = 1e-3;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
|
||||
{
|
||||
double rtol = 1e-2;
|
||||
double atol = 1e-2;
|
||||
return ck_tile::make_tuple(rtol, atol);
|
||||
}
|
||||
|
||||
template <>
|
||||
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
|
||||
{
|
||||
if(init_method == "ui" || init_method == "ni")
|
||||
{
|
||||
unsigned max_rounding_point_distance = 0;
|
||||
double atol = 2e-3;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
else
|
||||
{
|
||||
unsigned max_rounding_point_distance = 1;
|
||||
double atol = 0.0625;
|
||||
return ck_tile::make_tuple(max_rounding_point_distance, atol);
|
||||
}
|
||||
}
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("v", "1", "weather do CPU validation or not")
|
||||
.insert("pr_i", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
|
||||
.insert("pr_w", "fp32", "output weight data type(currently only fp32 supported now)")
|
||||
.insert("t", "32", "number of input tokens")
|
||||
.insert("e", "8", "number of experts")
|
||||
.insert("k", "2", "topk")
|
||||
.insert("st_i", "-1", "row stride of input, -1 means same as experts")
|
||||
.insert("st_o", "-1", "row stride of output/indices, -1 means same as topk")
|
||||
.insert("seed", "-1", "seed to be used, -1 means random every time")
|
||||
.insert("kname", "0", "when set to 1 it will print kernel name")
|
||||
.insert("warmup", "5", "number of iterations before benchmark the kernel")
|
||||
.insert("repeat", "20", "number of iterations to benchmark the kernel");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename InputType, typename WeightType, typename IndexType = ck_tile::index_t>
|
||||
bool test_topk_softmax(ck_tile::ArgParser args)
|
||||
{
|
||||
int validate = args.get_int("v");
|
||||
std::string input_prec = args.get_str("pr_i");
|
||||
std::string weight_prec = args.get_str("pr_w");
|
||||
int tokens = args.get_int("t");
|
||||
int experts = args.get_int("e");
|
||||
int topk = args.get_int("k");
|
||||
int seed = args.get_int("seed");
|
||||
int stride_input = args.get_int("st_i");
|
||||
int stride_output = args.get_int("st_o");
|
||||
int kname = args.get_int("kname");
|
||||
int warmup = args.get_int("warmup");
|
||||
int repeat = args.get_int("repeat");
|
||||
|
||||
if(stride_input < 0)
|
||||
{
|
||||
stride_input = experts;
|
||||
}
|
||||
if(stride_output < 0)
|
||||
{
|
||||
stride_output = topk;
|
||||
}
|
||||
assert(stride_input >= experts);
|
||||
assert(stride_output >= topk);
|
||||
|
||||
if(seed < 0)
|
||||
{
|
||||
seed = std::time(nullptr);
|
||||
}
|
||||
|
||||
if(topk > experts)
|
||||
{
|
||||
printf("topk:%d value should be smaller than, or equal to number of experts:%d\n",
|
||||
topk,
|
||||
experts);
|
||||
return false;
|
||||
}
|
||||
|
||||
// tokens already considered batch size
|
||||
ck_tile::HostTensor<InputType> x_host({tokens, experts}, {stride_input, 1});
|
||||
ck_tile::HostTensor<WeightType> value_host({tokens, topk}, {stride_output, 1});
|
||||
ck_tile::HostTensor<IndexType> index_host({tokens, topk}, {stride_output, 1});
|
||||
|
||||
{
|
||||
// random require per-row unique
|
||||
auto rand_gen = ck_tile::FillUniformDistribution_Unique<InputType>{
|
||||
-5.f, 5.f, static_cast<uint32_t>(seed)};
|
||||
|
||||
for(int i_t = 0; i_t < tokens; i_t++)
|
||||
{
|
||||
ck_tile::HostTensor<InputType> x_row({experts});
|
||||
rand_gen(x_row);
|
||||
std::copy(x_row.begin(), x_row.end(), x_host.begin() + i_t * stride_input);
|
||||
rand_gen.clear();
|
||||
}
|
||||
}
|
||||
|
||||
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem value_dev(value_host.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem index_dev(index_host.get_element_space_size_in_bytes());
|
||||
|
||||
x_dev.ToDevice(x_host.data());
|
||||
|
||||
topk_softmax_trait trait{input_prec, weight_prec, experts};
|
||||
|
||||
topk_softmax_kargs karg{x_dev.GetDeviceBuffer(),
|
||||
value_dev.GetDeviceBuffer(),
|
||||
index_dev.GetDeviceBuffer(),
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride_input,
|
||||
stride_output};
|
||||
|
||||
ck_tile::stream_config sc{nullptr,
|
||||
true,
|
||||
/* log_level = */ (kname ? 1 : 0),
|
||||
warmup,
|
||||
repeat};
|
||||
auto ms = topk_softmax(trait, karg, sc);
|
||||
printf("[%s|%s]tokens:%d, experts:%d, topk:%d, st_i:%d, st_o:%d, ms:%f, ",
|
||||
input_prec.c_str(),
|
||||
weight_prec.c_str(),
|
||||
tokens,
|
||||
experts,
|
||||
topk,
|
||||
stride_input,
|
||||
stride_output,
|
||||
ms);
|
||||
if(ms < 0)
|
||||
printf("not supported\n");
|
||||
fflush(stdout);
|
||||
if(ms < 0)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
value_dev.FromDevice(value_host.data());
|
||||
index_dev.FromDevice(index_host.data());
|
||||
|
||||
bool rtn = true;
|
||||
if(validate)
|
||||
{
|
||||
ck_tile::HostTensor<WeightType> value_ref({tokens, topk}, {stride_output, 1});
|
||||
ck_tile::HostTensor<IndexType> index_ref({tokens, topk}, {stride_output, 1});
|
||||
|
||||
reference_topk_softmax<InputType, WeightType, IndexType>(
|
||||
x_host, value_ref, index_ref, topk);
|
||||
|
||||
auto [rtol, atol] = get_elimit<InputType>("");
|
||||
for(int i_t = 0; i_t < tokens; i_t++)
|
||||
{
|
||||
auto s_begin = std::vector<size_t>{static_cast<size_t>(i_t), static_cast<size_t>(0)};
|
||||
auto s_end =
|
||||
std::vector<size_t>{static_cast<size_t>(i_t + 1), static_cast<size_t>(topk)};
|
||||
auto s_value_host = value_host.slice(s_begin, s_end);
|
||||
auto s_value_ref = value_ref.slice(s_begin, s_end);
|
||||
rtn &= ck_tile::check_err(s_value_host,
|
||||
s_value_ref,
|
||||
std::string("[") + std::to_string(i_t) +
|
||||
std::string("] Value Error:"),
|
||||
rtol,
|
||||
atol);
|
||||
auto s_index_host = index_host.slice(s_begin, s_end);
|
||||
auto s_index_ref = index_ref.slice(s_begin, s_end);
|
||||
rtn &= ck_tile::check_err(s_index_host,
|
||||
s_index_ref,
|
||||
std::string("[") + std::to_string(i_t) +
|
||||
std::string("] Index Error:"),
|
||||
rtol,
|
||||
atol);
|
||||
}
|
||||
}
|
||||
|
||||
printf("valid:%s\n", rtn ? "y" : "n");
|
||||
fflush(stdout);
|
||||
return rtn;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int run_gemm_combinations(std::string const& data_type)
|
||||
{
|
||||
char bufs[7][64];
|
||||
char* argv[7] = {bufs[0], bufs[1], bufs[2], bufs[3], bufs[4], bufs[5], bufs[6]};
|
||||
std::vector<std::vector<std::string>> params = {
|
||||
{"-t=80", "-e=17"},
|
||||
{"-t=111", "-e=117"},
|
||||
{"-t=1000", "-e=55"},
|
||||
{"-t=99", "-e=180"},
|
||||
{"-t=175", "-e=64", "-k=8"},
|
||||
{"-t=65", "-e=8", "-k=2"},
|
||||
{"-t=1", "-e=25"},
|
||||
{"-t=31", "-e=19", "-k=15"},
|
||||
{"-t=81", "-e=37", "-k=7"},
|
||||
{"-t=199", "-e=128", "-k=13"},
|
||||
{"-t=23", "-e=1", "-k=1"},
|
||||
{"-t=127", "-e=99", "-k=19", "-st_i=233", "-st_o=31"},
|
||||
{"-t=71", "-e=11", "-k=11", "-st_i=30", "-st_o=12"},
|
||||
{"-t=1", "-e=1", "-k=1"},
|
||||
{"-t=99", "-e=2", "-k=1", "-st_i=11", "-st_o=5"},
|
||||
{"-t=333", "-e=99", "-k=13", "-st_i=191", "-st_o=17"}};
|
||||
|
||||
bool result = true;
|
||||
std::string pr_i = "-pr_i=" + data_type;
|
||||
strncpy(bufs[0], "test_topk_softmax_bf16", 64);
|
||||
strncpy(bufs[1], pr_i.c_str(), 64);
|
||||
for(size_t i = 0; i < params.size(); i++)
|
||||
{
|
||||
for(size_t j = 0; j < params[i].size(); j++)
|
||||
{
|
||||
strncpy(bufs[j + 2], params[i][j].c_str(), 64);
|
||||
}
|
||||
int argc = params[i].size() + 2;
|
||||
|
||||
auto [good_args, args] = create_args(argc, argv);
|
||||
if(!good_args)
|
||||
{
|
||||
result = false;
|
||||
}
|
||||
result = test_topk_softmax<T, float, ck_tile::index_t>(args) && result;
|
||||
}
|
||||
return result ? 0 : -1;
|
||||
}
|
||||
96
test/ck_tile/topk_softmax/test_topk_softmax_api.cpp
Normal file
96
test/ck_tile/topk_softmax/test_topk_softmax_api.cpp
Normal file
@@ -0,0 +1,96 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_topk_softmax_api.hpp"
|
||||
|
||||
#define TOPK_SOFTMAX_DISPATCH(experts_) \
|
||||
constexpr ck_tile::index_t ts_experts = experts_; \
|
||||
using ts_problem = ck_tile:: \
|
||||
TopkSoftmaxWarpPerRowProblem<ts_input_type, ts_weight_type, ts_index_type, ts_experts>; \
|
||||
using ts_pipeline = ck_tile::TopkSoftmaxWarpPerRowPipeline<ts_problem>; \
|
||||
\
|
||||
using kernel = ck_tile::TopkSoftmaxKernel<ts_pipeline>; \
|
||||
\
|
||||
auto kargs = kernel::MakeKargs(a); \
|
||||
\
|
||||
const dim3 grids = kernel::GridSize(a); \
|
||||
constexpr dim3 blocks = kernel::BlockSize(); \
|
||||
\
|
||||
float ave_time = ck_tile::launch_kernel( \
|
||||
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs)); \
|
||||
\
|
||||
return ave_time;
|
||||
|
||||
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s)
|
||||
{
|
||||
if(t.input_type == "fp16" && t.weight_type == "fp32")
|
||||
{
|
||||
using ts_input_type = ck_tile::fp16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
#if 1
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192)
|
||||
}
|
||||
#else
|
||||
if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if(t.input_type == "bf16" && t.weight_type == "fp32")
|
||||
{
|
||||
#if 1
|
||||
using ts_input_type = ck_tile::bf16_t;
|
||||
using ts_weight_type = float;
|
||||
using ts_index_type = ck_tile::index_t;
|
||||
if(t.experts <= 8)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(8)
|
||||
}
|
||||
else if(t.experts <= 16)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(16)
|
||||
}
|
||||
else if(t.experts <= 32)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(32)
|
||||
}
|
||||
else if(t.experts <= 64)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(64)
|
||||
}
|
||||
else if(t.experts <= 128)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(128)
|
||||
}
|
||||
else if(t.experts <= 192)
|
||||
{
|
||||
TOPK_SOFTMAX_DISPATCH(192)
|
||||
}
|
||||
#endif
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
21
test/ck_tile/topk_softmax/test_topk_softmax_api.hpp
Normal file
21
test/ck_tile/topk_softmax/test_topk_softmax_api.hpp
Normal file
@@ -0,0 +1,21 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/topk_softmax.hpp"
|
||||
#include <string>
|
||||
|
||||
struct topk_softmax_trait
|
||||
{
|
||||
std::string input_type;
|
||||
std::string weight_type; // currently always float
|
||||
int experts;
|
||||
};
|
||||
|
||||
struct topk_softmax_kargs : public ck_tile::TopkSoftmaxHostArgs
|
||||
{
|
||||
};
|
||||
|
||||
float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_config s);
|
||||
6
test/ck_tile/topk_softmax/test_topk_softmax_bf16.cpp
Normal file
6
test/ck_tile/topk_softmax/test_topk_softmax_bf16.cpp
Normal file
@@ -0,0 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_topk_softmax.hpp"
|
||||
|
||||
int main() { return run_gemm_combinations<ck_tile::bf16_t>("bf16"); }
|
||||
6
test/ck_tile/topk_softmax/test_topk_softmax_fp16.cpp
Normal file
6
test/ck_tile/topk_softmax/test_topk_softmax_fp16.cpp
Normal file
@@ -0,0 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "test_topk_softmax.hpp"
|
||||
|
||||
int main() { return run_gemm_combinations<ck_tile::fp16_t>("fp16"); }
|
||||
Reference in New Issue
Block a user