Tests for CK Tile Batched Transpose and Smoothquant (#2453)

* Create tests for ck tile batched transpose using example

* Create ck tile tests for smoothquant using examples

* fix precision input strings and convert batched transpose to regression tests

* Code cleanup and fix asserts

* add missing licenses

* update copyright and licensing in files

* Update smoothquant tests to use example's smoothquant.cpp

* Add custom target for batched transpose tests

* Add missing new lines at end of files for CMakelists

* fix typo in batched transpose CMakeList target_compile_options

---------

Co-authored-by: root <root@ctr-ubbsmc16.amd.com>
This commit is contained in:
Emily Martins
2025-07-17 09:53:34 -06:00
committed by GitHub
parent 7fc000d7b3
commit c08986b026
36 changed files with 1391 additions and 0 deletions

View File

@@ -37,6 +37,9 @@ set(REGRESSION_TESTS
test_grouped_convnd_bwd_data_xdl
test_conv_tensor_rearrange
test_gemm_mx
test_ck_tile_batched_transpose_fp8
test_ck_tile_batched_transpose_fp16
test_ck_tile_batched_transpose_bf16
)
function(add_test_executable TEST_NAME)

View File

@@ -6,3 +6,5 @@ add_subdirectory(grouped_gemm)
add_subdirectory(gemm_multi_d)
add_subdirectory(data_type)
add_subdirectory(slice_tile)
add_subdirectory(batched_transpose)
add_subdirectory(smoothquant)

View File

@@ -0,0 +1,33 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
function (add_batched_transpose_test TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
add_test_executable(${TARGET_NAME} ${MAIN_SRC} batched_transpose_api.cpp)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
# list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker)
target_compile_options(${TARGET_NAME} PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS})
endfunction(add_batched_transpose_test TARGET_NAME MAIN_SRC)
set(CUSTOM_TARGET_NAME test_ck_tile_batched_transpose)
add_custom_target(${CUSTOM_TARGET_NAME})
add_batched_transpose_test(test_ck_tile_batched_transpose_fp16 batched_transpose_fp16.cpp)
add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_fp16)
add_batched_transpose_test(test_ck_tile_batched_transpose_fp8 batched_transpose_fp8.cpp)
add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_fp8)
add_batched_transpose_test(test_ck_tile_batched_transpose_bf16 batched_transpose_bf16.cpp)
add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_bf16)
else()
message(DEBUG "Skipping ck_tile batched_transpose tests for current target")
endif()

View File

@@ -0,0 +1,25 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/reduce.hpp"
#include "ck_tile/ops/batched_transpose.hpp"
#include <vector>
#include <string>
#pragma once
struct batched_transpose_trait
{
std::string type;
std::string layout;
};
struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs
{
};
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s);

View File

@@ -0,0 +1,283 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <vector>
#include <iostream>
#include <numeric>
#include <cassert>
#include <cstdlib>
#include <iostream>
#include <time.h>
#include <unordered_set>
#include "batched_transpose.hpp"
// different threshold for different dtype
template <typename DataType>
auto get_elimit(std::string /*init_method*/)
{
double rtol = 1e-3;
double atol = 1e-3;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>(std::string /*init_method*/)
{
double rtol = 1e-2;
double atol = 1e-2;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::fp8_t>(std::string init_method)
{
if(init_method == "ui" || init_method == "ni")
{
unsigned max_rounding_point_distance = 0;
double atol = 2e-3;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
else
{
unsigned max_rounding_point_distance = 1;
double atol = 0.0625;
return ck_tile::make_tuple(max_rounding_point_distance, atol);
}
}
auto create_args(int argc, char* argv[], int index = 0)
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("v", "1", "whether do CPU validation or not")
.insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)")
.insert("N", "1", "input batch size. ")
.insert("C", "64", "input channel size.")
.insert("H", "18", "input height size.")
.insert("W", "64", "input width size. ")
.insert("layout_in", "NCHW", "input tensor data layout - NCHW by default")
.insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ")
.insert("warmup", "50", "number of iterations before benchmark the kernel")
.insert("repeat", "100", "number of iterations to benchmark the kernel")
.insert("seed", "-1", "seed to be used, -1 means random every time")
.insert("kname", "0", "t to 1 will print kernel name");
bool result = arg_parser.parse(argc, argv, index);
return std::make_tuple(result, arg_parser);
}
template <typename Type>
bool run_batched_transpose(ck_tile::ArgParser args)
{
int validate = args.get_int("v");
std::string prec = args.get_str("pr");
int N = args.get_int("N");
int C = args.get_int("C");
int H = args.get_int("H");
int W = args.get_int("W");
int n_warmup = args.get_int("warmup");
int n_repeat = args.get_int("repeat");
std::string layout_in = args.get_str("layout_in");
std::string layout_out = args.get_str("layout_out");
int seed = args.get_int("seed");
int dim_in[4], dim_out[4];
int stride_dim_in[4], stride_dim_out[4];
bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC";
bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW";
assert(nchw2nhwc != nhwc2nchw);
(void)nhwc2nchw;
dim_in[0] = N;
dim_in[1] = nchw2nhwc ? C : H;
dim_in[2] = nchw2nhwc ? H : W;
dim_in[3] = nchw2nhwc ? W : C;
dim_out[0] = N;
dim_out[1] = nchw2nhwc ? H : C;
dim_out[2] = nchw2nhwc ? W : H;
dim_out[3] = nchw2nhwc ? C : W;
stride_dim_in[0] = C * H * W;
stride_dim_in[1] = nchw2nhwc ? H * W : C * W;
stride_dim_in[2] = nchw2nhwc ? W : C;
stride_dim_in[3] = 1;
stride_dim_out[0] = C * H * W;
stride_dim_out[1] = nchw2nhwc ? C * W : H * W;
stride_dim_out[2] = nchw2nhwc ? C : W;
stride_dim_out[3] = 1;
if(seed < 0)
{
seed = std::time(nullptr);
}
ck_tile::HostTensor<Type> x_host(
{dim_in[0], dim_in[1], dim_in[2], dim_in[3]},
{stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]});
ck_tile::HostTensor<Type> y_host(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::FillUniformDistribution<Type>{-.5f, .5f}(x_host);
ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes());
x_dev.ToDevice(x_host.data());
auto trait = batched_transpose_trait{prec, layout_in};
uint32_t height = nchw2nhwc ? C : H * W;
uint32_t width = nchw2nhwc ? H * W : C;
batched_transpose_kargs karg = [&]() {
batched_transpose_kargs a_;
a_.p_input = x_dev.GetDeviceBuffer();
a_.p_output = y_dev.GetDeviceBuffer();
a_.batch = N;
a_.height = height;
a_.width = width;
return a_;
}();
ck_tile::stream_config sc{nullptr, true, n_warmup, n_repeat};
auto ms = batched_transpose(trait, karg, sc);
std::size_t num_operations = N * C * H * (W - 1);
std::size_t num_bytes = N * C * H * W * sizeof(Type);
float ave_time = ms * 1E-3;
float gb_per_sec = num_bytes / ms * 1.E-6;
float tflops = static_cast<float>(num_operations) / ms * 1.E-6;
std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H
<< ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out
<< " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops"
<< gb_per_sec << " GB/s, " << std::endl;
printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n",
prec.c_str(),
N,
C,
H,
W,
layout_in.c_str(),
ms);
if(ms < 0)
printf("------------------------------------not "
"supported-------------------------------------\n");
fflush(stdout);
if(ms < 0)
{
return false;
}
y_dev.FromDevice(y_host.data());
bool rtn = true;
if(validate)
{
// this host buffer will not copy to GPU, so no need use stride
ck_tile::HostTensor<Type> y_ref(
{dim_out[0], dim_out[1], dim_out[2], dim_out[3]},
{stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]});
ck_tile::reference_batched_transpose<Type>(x_host, y_ref, layout_in, layout_out);
auto [rtol, atol] = get_elimit<Type>("");
rtn &= ck_tile::check_err(
y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol);
}
printf("-----------------------------------------------------------------------valid:%s--------"
"--------------------------------------------------------------------\n",
rtn ? "y" : "n");
fflush(stdout);
return rtn;
}
template <typename PrecType>
bool run_test_case(int argc, char** argv)
{
auto [result, args] = create_args(argc, argv);
if(!result)
return false;
return run_batched_transpose<PrecType>(args);
}
template <typename PrecType>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
bool valid = true;
for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx)
{
constexpr int num_args = 7;
char* argv[num_args];
assert(test_cases[test_idx].size() == num_args &&
"invalid number of arguments in test case");
for(std::size_t idx = 0; idx < test_cases[test_idx].size(); ++idx)
{
argv[idx] = test_cases[test_idx][idx].data();
}
valid = valid && run_test_case<PrecType>(num_args, argv);
if(!valid)
break;
}
return valid;
}
std::vector<std::vector<std::string>> generate_test_cases(const std::string prec)
{
return {
{"-pr=" + prec, "-N=1", "-C=32", "-H=1", "-W=32", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=1", "-C=64", "-H=1", "-W=64", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=2", "-C=12", "-H=1", "-W=32", "-layout_in=NHWC", "-layout_out=NCHW"},
{"-pr=" + prec, "-N=3", "-C=1334", "-H=1", "-W=37", "-layout_in=NHWC", "-layout_out=NCHW"},
{"-pr=" + prec, "-N=4", "-C=27", "-H=1", "-W=32", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=5", "-C=1234", "-H=1", "-W=12", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=1", "-C=1", "-H=1", "-W=1", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=1", "-C=1", "-H=1", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"},
{"-pr=" + prec,
"-N=128",
"-C=1024",
"-H=64",
"-W=64",
"-layout_in=NCHW",
"-layout_out=NHWC"},
{"-pr=" + prec,
"-N=128",
"-C=1024",
"-H=64",
"-W=64",
"-layout_in=NHWC",
"-layout_out=NCHW"},
{"-pr=" + prec, "-N=16", "-C=64", "-H=32", "-W=128", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=16", "-C=64", "-H=128", "-W=32", "-layout_in=NHWC", "-layout_out=NCHW"},
{"-pr=" + prec, "-N=1", "-C=2048", "-H=1", "-W=1", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=1", "-C=2048", "-H=1", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"},
{"-pr=" + prec,
"-N=1",
"-C=1",
"-H=1024",
"-W=1024",
"-layout_in=NCHW",
"-layout_out=NHWC"},
{"-pr=" + prec,
"-N=1",
"-C=1",
"-H=1024",
"-W=1024",
"-layout_in=NHWC",
"-layout_out=NCHW"},
{"-pr=" + prec, "-N=8", "-C=16", "-H=8", "-W=16", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=8", "-C=16", "-H=8", "-W=16", "-layout_in=NHWC", "-layout_out=NCHW"},
{"-pr=" + prec, "-N=1", "-C=64", "-H=1", "-W=1024", "-layout_in=NCHW", "-layout_out=NHWC"},
{"-pr=" + prec, "-N=1", "-C=64", "-H=1024", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"}};
}

View File

@@ -0,0 +1,113 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "batched_transpose.hpp"
template <typename ts_type,
ck_tile::index_t block_x,
ck_tile::index_t block_y,
ck_tile::index_t warp_x,
ck_tile::index_t warp_y,
ck_tile::index_t thread_x,
ck_tile::index_t thread_y,
bool kPadM,
bool kPadN>
float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s)
{
uint32_t dim_stride = a.height * a.width;
a.dim_stride = dim_stride;
a.dim_block_h = block_y;
a.dim_block_w = block_x;
using block_tile = ck_tile::sequence<block_x, block_y>;
using warp_tile = ck_tile::sequence<warp_x, warp_y>;
using thread_tile = ck_tile::sequence<thread_x, thread_y>;
using ts_problem =
ck_tile::BatchedTransposeProblem<ts_type, block_tile, warp_tile, thread_tile, kPadM, kPadN>;
using ts_pipeline = ck_tile::BatchedTransposePipeline<ts_problem>;
using kernel = ck_tile::BatchedTransposeKernel<ts_pipeline>;
auto kargs = kernel::MakeKargs(a);
const dim3 grids = kernel::GridSize(a);
constexpr dim3 blocks = kernel::BlockSize();
printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z);
printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z);
printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n",
kargs.batch,
kargs.height,
kargs.width,
kargs.dim_stride);
printf("Launching Kernel...\n");
float ave_time = ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, 1>(kernel{}, grids, blocks, 0, kargs));
printf("Kernel finished...\n");
return ave_time;
}
// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y
#define FOREACH_TRANSPOSE_PARAM(F) \
F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \
F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \
F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \
F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \
F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \
F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false)
// Macro that defines one static function per line
#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \
static float \
transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \
batched_transpose_kargs& a, ck_tile::stream_config& s) \
{ \
return batched_transpose_dispatch<REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN>(a, s); \
}
FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN)
float batched_transpose(batched_transpose_trait t,
batched_transpose_kargs a,
ck_tile::stream_config s)
{
if(t.type == "fp8")
{
if(a.height % 64 == 0 && a.width % 64 == 0)
{
return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s);
}
else
{
return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s);
}
}
else if(t.type == "fp16")
{
if(a.height % 64 == 0 && a.width % 64 == 0)
{
return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s);
}
else
{
return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s);
}
}
else if(t.type == "bf16")
{
if(a.height % 64 == 0 && a.width % 64 == 0)
{
return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s);
}
else
{
return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s);
}
}
return -1;
}

View File

@@ -0,0 +1,10 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "batched_transpose.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = generate_test_cases("bf16");
return !run_test_cases<ck_tile::bf16_t>(test_cases);
}

View File

@@ -0,0 +1,10 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "batched_transpose.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = generate_test_cases("fp16");
return !run_test_cases<ck_tile::fp16_t>(test_cases);
}

View File

@@ -0,0 +1,10 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "batched_transpose.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = generate_test_cases("fp8");
return !run_test_cases<ck_tile::fp8_t>(test_cases);
}

View File

@@ -0,0 +1,28 @@
# Currently ck_tile is only built on gfx9
if(GPU_TARGETS MATCHES "gfx9")
function (add_smoothquant_test TARGET_NAME MAIN_SRC)
message(DEBUG "adding ${TARGET_NAME}")
add_test_executable(${TARGET_NAME} ${MAIN_SRC})
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
foreach(source IN LISTS ARGN)
list(APPEND INSTANCE_SRCS ${source})
endforeach()
target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS})
set(COMPILE_OPTIONS)
# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations
list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal)
target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS})
endfunction(add_smoothquant_test TARGET_NAME MAIN_SRC)
file(GLOB INSTANCE_SRCS instances/*.cpp)
add_smoothquant_test(test_ck_tile_smoothquant_fp16 smoothquant_fp16.cpp ${INSTANCE_SRCS})
add_smoothquant_test(test_ck_tile_smoothquant_bf16 smoothquant_bf16.cpp ${INSTANCE_SRCS})
else()
message(DEBUG "Skipping ck_tile smoothquant tests for current target")
endif()

View File

@@ -0,0 +1,21 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
#if 0
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 16, 4, 64, 1, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 4, true, false>>(const S&, A);
#endif
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,12 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 128, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 256, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,12 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 4, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 8, 4, 64, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 1, true , false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 1, 4, 64, 2, true , false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 2, 4, 64, 1, true , false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 3, 4, 64, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 6, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::bf16_t, 1, 12, 4, 64, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,21 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
#if 0
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 8, true ,false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 4, true ,false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 2, true ,false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 16, 4, 64, 1, true ,false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 4, true ,false>>(const S&, A);
#endif
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 2, 128, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 2, 128, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 2, 128, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,12 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 2, 128, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 1, 256, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 1, 256, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 128, 8,true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 256, 4,true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 1, 256, 2,true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 1, 1024, 1,true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,13 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 256, 8, true, true>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 256, 4, true, true>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 1, 1024, 2, true, true>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 1, 1024, 1, true, true>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,12 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 8, true , false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 4, true , false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 4, 4, 64, 2, true , false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 8, 4, 64, 1, true , false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 1, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 1, 4, 64, 2, true, false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 2, 4, 64, 1, true, false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant_instance_common.hpp"
// clang-format off
// rm rn tm tn vn pd 2p
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 3, 4, 64, 4, true , false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 6, 4, 64, 2, true , false>>(const S&, A);
template float smoothquant_<trait_<ck_tile::fp16_t, 1, 12, 4, 64, 1, true , false>>(const S&, A);
// clang-format on

View File

@@ -0,0 +1,143 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <ck_tile/core.hpp>
#include "smoothquant.hpp"
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
using trait_ = smoothquant_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kTwoPass_>;
template <typename data_type>
float smoothquant_dispatch(smoothquant_traits /*t*/,
smoothquant_args a,
const ck_tile::stream_config& s)
{
float r = -1;
// clang-format off
// rm rn tm tn vn pd 2p
if(a.n <= 64) {
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 1, true, false>>(s, a);
}
else if(a.n <= 128) {
if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 1, true, false>>(s, a);
}
else if(a.n <= 256) {
if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 1, true, false>>(s, a);
}
else if(a.n <= 512) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 4, 64, 8, true, false>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 4, 64, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 4, 64, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 8, 4, 64, 1, true, false>>(s, a);
}
else if(a.n <= 768) {
if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 6, 4, 64, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1,12, 4, 64, 1, true, false>>(s, a);
}
else if(a.n <= 1024) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 2, 128, 8, true, false>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 2, 128, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 2, 128, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 1, true, false>>(s, a);
}
else if(a.n <= 1536) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 4, 64, 8, true, false>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 2, 128, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 1, true, false>>(s, a);
}
else if(a.n <= 2048) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 1, 1, 256, 8, true, false>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 8, 1, 256, 1, true, false>>(s, a);
}
else if(a.n <= 3072) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 128, 8, true, false>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 3, 1, 256, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 6, 1, 256, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 3, 1, 1024, 1, true, false>>(s, a);
}
else if(a.n <= 4096) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, false>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, false>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, false>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, false>>(s, a);
}
else if(a.n > 4096) {
if (a.n % 8 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 256, 8, true, true>>(s, a);
else if (a.n % 4 == 0)
r = smoothquant_<trait_<data_type, 1, 4, 1, 256, 4, true, true>>(s, a);
else if (a.n % 2 == 0)
r = smoothquant_<trait_<data_type, 1, 2, 1, 1024, 2, true, true>>(s, a);
else
r = smoothquant_<trait_<data_type, 1, 4, 1, 1024, 1, true, true>>(s, a);
}
return r;
// clang-format on
}
float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::stream_config& s)
{
if(t.data_type.compare("fp16") == 0)
{
return smoothquant_dispatch<ck_tile::fp16_t>(t, a, s);
}
else if(t.data_type.compare("bf16") == 0)
{
return smoothquant_dispatch<ck_tile::bf16_t>(t, a, s);
}
else
throw std::runtime_error("Without supported instances!");
}

View File

@@ -0,0 +1,61 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include <ck_tile/core.hpp>
#include "smoothquant.hpp"
#include <iostream>
#pragma once
using S = ck_tile::stream_config;
using A = smoothquant_args;
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
using trait_ = smoothquant_traits_<DataType_,
Repeat_M_,
Repeat_N_,
ThreadPerBlock_M_,
ThreadPerBlock_N_,
Vector_N_,
kPadN_,
kTwoPass_>;
template <typename Traits_>
float smoothquant_(const S& s, A a)
{
using DataType = typename Traits_::DataType;
using PipelineProblem = ck_tile::SmoothquantPipelineProblem<
typename SmoothquantTypeConfig<DataType>::XDataType,
typename SmoothquantTypeConfig<DataType>::SmoothScaleDataType,
typename SmoothquantTypeConfig<DataType>::ComputeDataType,
typename SmoothquantTypeConfig<DataType>::YScaleDataType,
typename SmoothquantTypeConfig<DataType>::QYDataType,
typename Traits_::Shape,
Traits_::kPadN,
Traits_::kTwoPass>;
using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass<PipelineProblem>;
using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass<PipelineProblem>;
using Pipeline = std::conditional_t<Traits_::kTwoPass, TwoPassPipeline, OnePassPipeline>;
using Kernel = ck_tile::Smoothquant<Pipeline>;
const dim3 grids = Kernel::GridSize(a);
constexpr dim3 blocks = Kernel::BlockSize();
constexpr ck_tile::index_t kBlockPerCu = 1;
auto kargs = Kernel::MakeKargs(a);
if(s.log_level_ > 0)
std::cout << ", " << Kernel::GetName() << std::flush;
return ck_tile::launch_kernel(
s, ck_tile::make_kernel<blocks.x, kBlockPerCu>(Kernel{}, grids, blocks, 0, kargs));
}

View File

@@ -0,0 +1,114 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/smoothquant.hpp"
#include <string>
template <typename DataType>
struct SmoothquantTypeConfig;
template <>
struct SmoothquantTypeConfig<ck_tile::half_t>
{
using XDataType = ck_tile::half_t;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
template <>
struct SmoothquantTypeConfig<ck_tile::bf16_t>
{
using XDataType = ck_tile::bf16_t;
using SmoothScaleDataType = float;
using YScaleDataType = float;
using QYDataType = ck_tile::int8_t;
using ComputeDataType = float;
};
// runtime args
struct smoothquant_args : public ck_tile::SmoothquantHostArgs
{
};
// this is used to pattern-match internl kernel implementation, not to instantiate kernel
template <typename DataType_,
ck_tile::index_t Repeat_M_, // each thread repeat along M
ck_tile::index_t Repeat_N_, // each thread repeat along N
ck_tile::index_t ThreadPerBlock_M_, // num threads along M
ck_tile::index_t ThreadPerBlock_N_, // num threads along N
ck_tile::index_t Vector_N_, // vector size along N
bool kPadN_,
bool kTwoPass_>
struct smoothquant_traits_
{
using DataType = ck_tile::remove_cvref_t<DataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_);
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_;
static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_;
static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M;
static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_;
using BlockTile = ck_tile::sequence<Block_M, Block_N>;
using BlockWarps = ck_tile::sequence<BlockWarps_M, BlockWarps_N>;
using WarpTile = ck_tile::sequence<Warp_M, Warp_N>;
using Vector = ck_tile::sequence<1, Vector_N_>;
using Shape = ck_tile::Generic2dBlockShape<BlockTile, BlockWarps, WarpTile, Vector>;
static constexpr bool kPadN = kPadN_;
static constexpr bool kTwoPass = kTwoPass_;
};
template <typename Traits_>
float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a);
// This is the public API, will be generated by script
struct smoothquant_traits
{
std::string data_type;
};
float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&);

View File

@@ -0,0 +1,274 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/host.hpp"
#include "smoothquant.hpp"
#include <cstring>
// different threshold for different dtype
template <typename DataType>
auto get_elimit()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::bf16_t>()
{
double rtol = 1e-5;
double atol = 1e-5;
return ck_tile::make_tuple(rtol, atol);
}
template <>
auto get_elimit<ck_tile::int8_t>()
{
// due to rounding, int8 quantization might have 1 abs error
double rtol = 1;
double atol = 1;
return ck_tile::make_tuple(rtol, atol);
}
auto create_args(int argc, char* argv[], int index = 0)
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "3328", "m dimension")
.insert("n", "4096", "n dimension")
.insert("x_stride", "-1", "input stride per row, if -1 then equal to n")
.insert("y_stride", "-1", "output stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not")
.insert("kname", "1", "print kernel name or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "5", "cold iter")
.insert("repeat", "20", "hot iter");
bool result = arg_parser.parse(argc, argv, index);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
x_stride = n;
ck_tile::index_t y_stride = arg_parser.get_int("y_stride");
if(y_stride < 0)
y_stride = n;
std::string data_type = arg_parser.get_str("prec");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
assert(x_stride >= n);
using TypeConfig = SmoothquantTypeConfig<DataType>;
using XDataType = typename TypeConfig::XDataType;
using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType;
using YScaleDataType = typename TypeConfig::YScaleDataType;
using QYDataType = typename TypeConfig::QYDataType;
using ComputeDataType = typename TypeConfig::ComputeDataType;
// host verify
ck_tile::HostTensor<XDataType> x_host({m, n}, {x_stride, 1});
ck_tile::HostTensor<SmoothScaleDataType> smscale_host({n});
ck_tile::HostTensor<YScaleDataType> yscale_host_ref({m}, {1});
ck_tile::HostTensor<YScaleDataType> yscale_host_dev({m}, {1});
ck_tile::HostTensor<QYDataType> qy_host_ref({m, n}, {y_stride, 1});
ck_tile::HostTensor<QYDataType> qy_host_dev({m, n}, {y_stride, 1});
ck_tile::FillUniformDistribution<XDataType>{-.5f, .5f}(x_host);
ck_tile::FillUniformDistribution<SmoothScaleDataType>{1e-3, .5f}(smscale_host);
ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes());
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes());
ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes());
x_buf.ToDevice(x_host.data());
smscale_buf.ToDevice(smscale_host.data());
std::cout << "[" << data_type << "]"
<< " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride
<< std::flush;
smoothquant_traits traits{data_type};
smoothquant_args args{x_buf.GetDeviceBuffer(),
smscale_buf.GetDeviceBuffer(),
yscale_buf.GetDeviceBuffer(),
qy_buf.GetDeviceBuffer(),
m,
n,
x_stride,
y_stride};
float ave_time = smoothquant(
traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat});
std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n +
sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n;
float gb_per_sec = num_byte / 1.E6 / ave_time;
std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush;
bool pass = true;
if(do_validation)
{
using YDataType = ComputeDataType;
ck_tile::HostTensor<ComputeDataType> y_host({m, n}, {y_stride, 1});
// smooth outlier
{
auto f = [&](auto n_) {
auto v_smscale = ck_tile::type_convert<ComputeDataType>(smscale_host(n_));
for(int m_ = 0; m_ < m; ++m_)
{
auto v_x = ck_tile::type_convert<ComputeDataType>(x_host(m_, n_));
y_host(m_, n_) = v_x * v_smscale;
}
};
ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())(
std::thread::hardware_concurrency());
}
// yscale
{
ck_tile::HostTensor<YDataType> y_rowwise_amax_host({m});
using ReduceAmax = ck_tile::ReduceOp::AbsMax;
ck_tile::reference_reduce<ComputeDataType, ComputeDataType, YDataType>(
y_host, y_rowwise_amax_host, ReduceAmax{});
auto op = [](const auto& v0) {
return v0 /
ck_tile::type_convert<ComputeDataType>(ck_tile::numeric<QYDataType>::max());
};
ck_tile::reference_unary_elementwise<YDataType, YScaleDataType, ComputeDataType>(
y_rowwise_amax_host, yscale_host_ref, op);
yscale_buf.FromDevice(yscale_host_dev.mData.data());
auto [rtol, atol] = get_elimit<YScaleDataType>();
pass &= ck_tile::check_err(yscale_host_dev,
yscale_host_ref,
std::string("yscale Error: Incorrect results!"),
rtol,
atol);
}
// rowwise quantization
{
ck_tile::reference_rowwise_quantization2d<YDataType, YScaleDataType, QYDataType>(
y_host, yscale_host_ref, qy_host_ref);
qy_buf.FromDevice(qy_host_dev.data());
auto [rtol, atol] = get_elimit<QYDataType>();
if(y_stride == n)
{
pass = ck_tile::check_err(qy_host_dev,
qy_host_ref,
std::string("qy Error: Incorrect results!"),
rtol,
atol);
}
else
{
for(int i_r = 0; i_r < m; i_r++)
{
std::vector<QYDataType> qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride,
qy_host_dev.begin() + i_r * y_stride +
n);
std::vector<QYDataType> qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride,
qy_host_ref.begin() + i_r * y_stride +
n);
pass &= ck_tile::check_err(qy_host_dev_row,
qy_host_ref_row,
std::string("qy[") + std::to_string(i_r) +
std::string("] Error: Incorrect results!"),
rtol,
atol);
}
}
}
std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl;
}
return pass;
}
std::vector<std::vector<std::string>> create_test_cases(const std::string prec)
{
return {{"-prec=" + prec, "-m=99", "-n=13", "-x_stride=-1"},
{"-prec=" + prec, "-m=17", "-n=16", "-x_stride=-1"},
{"-prec=" + prec, "-m=1", "-n=100", "-x_stride=-1"},
{"-prec=" + prec, "-m=4", "-n=128", "-x_stride=-1"},
{"-prec=" + prec, "-m=80", "-n=127", "-x_stride=-1"},
{"-prec=" + prec, "-m=22", "-n=255", "-x_stride=256"},
{"-prec=" + prec, "-m=7", "-n=599", "-x_stride=-1"},
{"-prec=" + prec, "-m=19", "-n=512", "-x_stride=-1"},
{"-prec=" + prec, "-m=33", "-n=313", "-x_stride=1000"},
{"-prec=" + prec, "-m=11", "-n=510", "-x_stride=-1"},
{"-prec=" + prec, "-m=171", "-n=676", "-x_stride=818"},
{"-prec=" + prec, "-m=91", "-n=636", "-x_stride=-1"},
{"-prec=" + prec, "-m=12", "-n=768", "-x_stride=800"},
{"-prec=" + prec, "-m=100", "-n=766", "-x_stride=812"},
{"-prec=" + prec, "-m=31", "-n=1024", "-x_stride=-1"},
{"-prec=" + prec, "-m=64", "-n=1000", "-x_stride=1004"},
{"-prec=" + prec, "-m=8", "-n=1501", "-x_stride=-1"},
{"-prec=" + prec, "-m=3", "-n=1826", "-x_stride=-1"},
{"-prec=" + prec, "-m=5", "-n=2040", "-x_stride=-1"},
{"-prec=" + prec, "-m=7", "-n=2734", "-x_stride=-1"},
{"-prec=" + prec, "-m=1", "-n=3182", "-x_stride=-1"},
{"-prec=" + prec, "-m=9", "-n=4096", "-x_stride=-1"},
{"-prec=" + prec, "-m=3", "-n=8192", "-x_stride=-1"},
{"-prec=" + prec, "-m=1", "-n=10547", "-x_stride=-1"},
{"-prec=" + prec, "-m=3", "-n=17134", "-x_stride=-1"}};
}
template <typename DataType>
bool run_test_case(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return false;
return run<DataType>(arg_parser);
}
template <typename DataType>
bool run_test_cases(std::vector<std::vector<std::string>>& test_cases)
{
bool valid = true;
constexpr int num_args = 4;
char* argv[num_args];
for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx)
{
assert(test_cases[test_idx].size() == num_args &&
"invalid number of arguments in test case");
for(std::size_t idx = 0; idx < num_args; ++idx)
{
argv[idx] = test_cases[test_idx][idx].data();
}
valid = valid && run_test_case<DataType>(num_args, argv);
if(!valid)
break;
}
return valid;
}

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("bf16");
return !run_test_cases<ck_tile::bf16_t>(test_cases);
}

View File

@@ -0,0 +1,11 @@
// Copyright © Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "smoothquant.inc"
int main()
{
std::vector<std::vector<std::string>> test_cases = create_test_cases("fp16");
return !run_test_cases<ck_tile::half_t>(test_cases);
}