diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 1be7c88c2e..c738eab802 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -37,6 +37,9 @@ set(REGRESSION_TESTS test_grouped_convnd_bwd_data_xdl test_conv_tensor_rearrange test_gemm_mx + test_ck_tile_batched_transpose_fp8 + test_ck_tile_batched_transpose_fp16 + test_ck_tile_batched_transpose_bf16 ) function(add_test_executable TEST_NAME) diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index cc933012ac..5c0f3fb076 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -6,3 +6,5 @@ add_subdirectory(grouped_gemm) add_subdirectory(gemm_multi_d) add_subdirectory(data_type) add_subdirectory(slice_tile) +add_subdirectory(batched_transpose) +add_subdirectory(smoothquant) diff --git a/test/ck_tile/batched_transpose/CMakeLists.txt b/test/ck_tile/batched_transpose/CMakeLists.txt new file mode 100644 index 0000000000..ac8e3dac49 --- /dev/null +++ b/test/ck_tile/batched_transpose/CMakeLists.txt @@ -0,0 +1,33 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + + function (add_batched_transpose_test TARGET_NAME MAIN_SRC) + message(DEBUG "adding ${TARGET_NAME}") + + add_test_executable(${TARGET_NAME} ${MAIN_SRC} batched_transpose_api.cpp) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + # list(APPEND EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS -v --save-temps -Wno-gnu-line-marker) + target_compile_options(${TARGET_NAME} PRIVATE ${EXAMPLE_BATCHED_TRANSPOSE_COMPILE_OPTIONS}) + + endfunction(add_batched_transpose_test TARGET_NAME MAIN_SRC) + + set(CUSTOM_TARGET_NAME test_ck_tile_batched_transpose) + + add_custom_target(${CUSTOM_TARGET_NAME}) + + add_batched_transpose_test(test_ck_tile_batched_transpose_fp16 batched_transpose_fp16.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_fp16) + + add_batched_transpose_test(test_ck_tile_batched_transpose_fp8 batched_transpose_fp8.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_fp8) + + add_batched_transpose_test(test_ck_tile_batched_transpose_bf16 batched_transpose_bf16.cpp) + add_dependencies(${CUSTOM_TARGET_NAME} test_ck_tile_batched_transpose_bf16) + + +else() + message(DEBUG "Skipping ck_tile batched_transpose tests for current target") +endif() diff --git a/test/ck_tile/batched_transpose/batched_transpose.hpp b/test/ck_tile/batched_transpose/batched_transpose.hpp new file mode 100644 index 0000000000..bd1abb1191 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose.hpp @@ -0,0 +1,25 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "ck_tile/core.hpp" +#include "ck_tile/host.hpp" +#include "ck_tile/ops/reduce.hpp" +#include "ck_tile/ops/batched_transpose.hpp" + +#include +#include + +#pragma once + +struct batched_transpose_trait +{ + std::string type; + std::string layout; +}; + +struct batched_transpose_kargs : public ck_tile::BatchedTransposeHostArgs +{ +}; + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s); diff --git a/test/ck_tile/batched_transpose/batched_transpose.inc b/test/ck_tile/batched_transpose/batched_transpose.inc new file mode 100644 index 0000000000..30084f5664 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose.inc @@ -0,0 +1,283 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "batched_transpose.hpp" + +// different threshold for different dtype +template +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-3; + double atol = 1e-3; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string /*init_method*/) +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit(std::string init_method) +{ + if(init_method == "ui" || init_method == "ni") + { + unsigned max_rounding_point_distance = 0; + double atol = 2e-3; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } + else + { + unsigned max_rounding_point_distance = 1; + double atol = 0.0625; + return ck_tile::make_tuple(max_rounding_point_distance, atol); + } +} + +auto create_args(int argc, char* argv[], int index = 0) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("v", "1", "whether do CPU validation or not") + .insert("pr", "fp16", "input data type. fp16/fp32 (representing 8/16/32 bit data)") + .insert("N", "1", "input batch size. ") + .insert("C", "64", "input channel size.") + .insert("H", "18", "input height size.") + .insert("W", "64", "input width size. ") + .insert("layout_in", "NCHW", "input tensor data layout - NCHW by default") + .insert("layout_out", "NHWC", "output tensor data layout - NHWC by default ") + .insert("warmup", "50", "number of iterations before benchmark the kernel") + .insert("repeat", "100", "number of iterations to benchmark the kernel") + .insert("seed", "-1", "seed to be used, -1 means random every time") + .insert("kname", "0", "t to 1 will print kernel name"); + + bool result = arg_parser.parse(argc, argv, index); + return std::make_tuple(result, arg_parser); +} + +template +bool run_batched_transpose(ck_tile::ArgParser args) +{ + int validate = args.get_int("v"); + std::string prec = args.get_str("pr"); + int N = args.get_int("N"); + int C = args.get_int("C"); + int H = args.get_int("H"); + int W = args.get_int("W"); + int n_warmup = args.get_int("warmup"); + int n_repeat = args.get_int("repeat"); + std::string layout_in = args.get_str("layout_in"); + std::string layout_out = args.get_str("layout_out"); + int seed = args.get_int("seed"); + + int dim_in[4], dim_out[4]; + int stride_dim_in[4], stride_dim_out[4]; + bool nchw2nhwc = layout_in == "NCHW" && layout_out == "NHWC"; + bool nhwc2nchw = layout_in == "NHWC" && layout_out == "NCHW"; + assert(nchw2nhwc != nhwc2nchw); + (void)nhwc2nchw; + + dim_in[0] = N; + dim_in[1] = nchw2nhwc ? C : H; + dim_in[2] = nchw2nhwc ? H : W; + dim_in[3] = nchw2nhwc ? W : C; + dim_out[0] = N; + dim_out[1] = nchw2nhwc ? H : C; + dim_out[2] = nchw2nhwc ? W : H; + dim_out[3] = nchw2nhwc ? C : W; + stride_dim_in[0] = C * H * W; + stride_dim_in[1] = nchw2nhwc ? H * W : C * W; + stride_dim_in[2] = nchw2nhwc ? W : C; + stride_dim_in[3] = 1; + stride_dim_out[0] = C * H * W; + stride_dim_out[1] = nchw2nhwc ? C * W : H * W; + stride_dim_out[2] = nchw2nhwc ? C : W; + stride_dim_out[3] = 1; + + if(seed < 0) + { + seed = std::time(nullptr); + } + + ck_tile::HostTensor x_host( + {dim_in[0], dim_in[1], dim_in[2], dim_in[3]}, + {stride_dim_in[0], stride_dim_in[1], stride_dim_in[2], stride_dim_in[3]}); + ck_tile::HostTensor y_host( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + + ck_tile::DeviceMem x_dev(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_dev(y_host.get_element_space_size_in_bytes()); + + x_dev.ToDevice(x_host.data()); + + auto trait = batched_transpose_trait{prec, layout_in}; + + uint32_t height = nchw2nhwc ? C : H * W; + uint32_t width = nchw2nhwc ? H * W : C; + + batched_transpose_kargs karg = [&]() { + batched_transpose_kargs a_; + a_.p_input = x_dev.GetDeviceBuffer(); + a_.p_output = y_dev.GetDeviceBuffer(); + a_.batch = N; + a_.height = height; + a_.width = width; + return a_; + }(); + + ck_tile::stream_config sc{nullptr, true, n_warmup, n_repeat}; + + auto ms = batched_transpose(trait, karg, sc); + + std::size_t num_operations = N * C * H * (W - 1); + std::size_t num_bytes = N * C * H * W * sizeof(Type); + + float ave_time = ms * 1E-3; + float gb_per_sec = num_bytes / ms * 1.E-6; + float tflops = static_cast(num_operations) / ms * 1.E-6; + + std::cout << "Run Batched Transpose kernel with N=" << N << ", C=" << C << ", H=" << H + << ", W=" << W << ", layout_in=" << layout_in << ", layout_out=" << layout_out + << " : " << ms << " ms (" << ave_time << " ave_time), " << tflops << " TFlops" + << gb_per_sec << " GB/s, " << std::endl; + + printf("[%s]N:%d, C:%d, H:%d, W:%d, layout_in:%s, %f\n", + prec.c_str(), + N, + C, + H, + W, + layout_in.c_str(), + ms); + if(ms < 0) + printf("------------------------------------not " + "supported-------------------------------------\n"); + fflush(stdout); + + if(ms < 0) + { + return false; + } + + y_dev.FromDevice(y_host.data()); + + bool rtn = true; + if(validate) + { + // this host buffer will not copy to GPU, so no need use stride + ck_tile::HostTensor y_ref( + {dim_out[0], dim_out[1], dim_out[2], dim_out[3]}, + {stride_dim_out[0], stride_dim_out[1], stride_dim_out[2], stride_dim_out[3]}); + + ck_tile::reference_batched_transpose(x_host, y_ref, layout_in, layout_out); + + auto [rtol, atol] = get_elimit(""); + + rtn &= ck_tile::check_err( + y_host, y_ref, std::string("y Error: Incorrect results!"), rtol, atol); + } + printf("-----------------------------------------------------------------------valid:%s--------" + "--------------------------------------------------------------------\n", + rtn ? "y" : "n"); + fflush(stdout); + return rtn; +} + +template +bool run_test_case(int argc, char** argv) +{ + auto [result, args] = create_args(argc, argv); + if(!result) + return false; + + return run_batched_transpose(args); +} + +template +bool run_test_cases(std::vector>& test_cases) +{ + bool valid = true; + for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx) + { + constexpr int num_args = 7; + char* argv[num_args]; + + assert(test_cases[test_idx].size() == num_args && + "invalid number of arguments in test case"); + + for(std::size_t idx = 0; idx < test_cases[test_idx].size(); ++idx) + { + argv[idx] = test_cases[test_idx][idx].data(); + } + + valid = valid && run_test_case(num_args, argv); + + if(!valid) + break; + } + + return valid; +} + +std::vector> generate_test_cases(const std::string prec) +{ + return { + {"-pr=" + prec, "-N=1", "-C=32", "-H=1", "-W=32", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=64", "-H=1", "-W=64", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=2", "-C=12", "-H=1", "-W=32", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=3", "-C=1334", "-H=1", "-W=37", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=4", "-C=27", "-H=1", "-W=32", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=5", "-C=1234", "-H=1", "-W=12", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=1", "-H=1", "-W=1", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=1", "-H=1", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, + "-N=128", + "-C=1024", + "-H=64", + "-W=64", + "-layout_in=NCHW", + "-layout_out=NHWC"}, + {"-pr=" + prec, + "-N=128", + "-C=1024", + "-H=64", + "-W=64", + "-layout_in=NHWC", + "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=16", "-C=64", "-H=32", "-W=128", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=16", "-C=64", "-H=128", "-W=32", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=1", "-C=2048", "-H=1", "-W=1", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=2048", "-H=1", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, + "-N=1", + "-C=1", + "-H=1024", + "-W=1024", + "-layout_in=NCHW", + "-layout_out=NHWC"}, + {"-pr=" + prec, + "-N=1", + "-C=1", + "-H=1024", + "-W=1024", + "-layout_in=NHWC", + "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=8", "-C=16", "-H=8", "-W=16", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=8", "-C=16", "-H=8", "-W=16", "-layout_in=NHWC", "-layout_out=NCHW"}, + {"-pr=" + prec, "-N=1", "-C=64", "-H=1", "-W=1024", "-layout_in=NCHW", "-layout_out=NHWC"}, + {"-pr=" + prec, "-N=1", "-C=64", "-H=1024", "-W=1", "-layout_in=NHWC", "-layout_out=NCHW"}}; +} diff --git a/test/ck_tile/batched_transpose/batched_transpose_api.cpp b/test/ck_tile/batched_transpose/batched_transpose_api.cpp new file mode 100644 index 0000000000..27c2269a06 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose_api.cpp @@ -0,0 +1,113 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "batched_transpose.hpp" + +template +float batched_transpose_dispatch(batched_transpose_kargs& a, ck_tile::stream_config& s) +{ + uint32_t dim_stride = a.height * a.width; + + a.dim_stride = dim_stride; + a.dim_block_h = block_y; + a.dim_block_w = block_x; + + using block_tile = ck_tile::sequence; + using warp_tile = ck_tile::sequence; + using thread_tile = ck_tile::sequence; + + using ts_problem = + ck_tile::BatchedTransposeProblem; + using ts_pipeline = ck_tile::BatchedTransposePipeline; + + using kernel = ck_tile::BatchedTransposeKernel; + + auto kargs = kernel::MakeKargs(a); + + const dim3 grids = kernel::GridSize(a); + constexpr dim3 blocks = kernel::BlockSize(); + + printf("Grid: %u %u %u\n", grids.x, grids.y, grids.z); + printf("Block: %u %u %u\n", blocks.x, blocks.y, blocks.z); + printf("kargs: kargs.batch %d kargs.height %d kargs.width %d kargs.dim_strid %d\n", + kargs.batch, + kargs.height, + kargs.width, + kargs.dim_stride); + + printf("Launching Kernel...\n"); + + float ave_time = ck_tile::launch_kernel( + s, ck_tile::make_kernel(kernel{}, grids, blocks, 0, kargs)); + + printf("Kernel finished...\n"); + + return ave_time; +} + +// Param Comb: type_size, block_x & y, warp_x & y, thread_x & y +#define FOREACH_TRANSPOSE_PARAM(F) \ + F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(fp8, ck_tile::fp8_t, 64, 64, 64, 64, 8, 8, false, false) \ + F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(fp16, ck_tile::fp16_t, 64, 64, 64, 64, 8, 8, false, false) \ + F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, true, true) \ + F(bf16, ck_tile::bf16_t, 64, 64, 64, 64, 8, 8, false, false) + +// Macro that defines one static function per line +#define GEN_TRANSPOSE_FN(SHORT_NAME, REAL_TYPE, BX, BY, WX, WY, TX, TY, PADM, PADN) \ + static float \ + transpose_fn_##SHORT_NAME##_##BX##_##BY##_##WX##_##WY##_##TX##_##TY##_##PADM##_##PADN( \ + batched_transpose_kargs& a, ck_tile::stream_config& s) \ + { \ + return batched_transpose_dispatch(a, s); \ + } + +FOREACH_TRANSPOSE_PARAM(GEN_TRANSPOSE_FN) + +float batched_transpose(batched_transpose_trait t, + batched_transpose_kargs a, + ck_tile::stream_config s) +{ + if(t.type == "fp8") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp8_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_fp8_64_64_64_64_8_8_true_true(a, s); + } + } + else if(t.type == "fp16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_fp16_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_fp16_64_64_64_64_8_8_true_true(a, s); + } + } + else if(t.type == "bf16") + { + if(a.height % 64 == 0 && a.width % 64 == 0) + { + return transpose_fn_bf16_64_64_64_64_8_8_false_false(a, s); + } + else + { + return transpose_fn_bf16_64_64_64_64_8_8_true_true(a, s); + } + } + return -1; +} diff --git a/test/ck_tile/batched_transpose/batched_transpose_bf16.cpp b/test/ck_tile/batched_transpose/batched_transpose_bf16.cpp new file mode 100644 index 0000000000..42642335f6 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose_bf16.cpp @@ -0,0 +1,10 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "batched_transpose.inc" + +int main() +{ + std::vector> test_cases = generate_test_cases("bf16"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/batched_transpose/batched_transpose_fp16.cpp b/test/ck_tile/batched_transpose/batched_transpose_fp16.cpp new file mode 100644 index 0000000000..5562dd54e8 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose_fp16.cpp @@ -0,0 +1,10 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "batched_transpose.inc" + +int main() +{ + std::vector> test_cases = generate_test_cases("fp16"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/batched_transpose/batched_transpose_fp8.cpp b/test/ck_tile/batched_transpose/batched_transpose_fp8.cpp new file mode 100644 index 0000000000..45e79fb4c2 --- /dev/null +++ b/test/ck_tile/batched_transpose/batched_transpose_fp8.cpp @@ -0,0 +1,10 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT +#include "batched_transpose.inc" + +int main() +{ + std::vector> test_cases = generate_test_cases("fp8"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/smoothquant/CMakeLists.txt b/test/ck_tile/smoothquant/CMakeLists.txt new file mode 100644 index 0000000000..de4459051c --- /dev/null +++ b/test/ck_tile/smoothquant/CMakeLists.txt @@ -0,0 +1,28 @@ +# Currently ck_tile is only built on gfx9 +if(GPU_TARGETS MATCHES "gfx9") + function (add_smoothquant_test TARGET_NAME MAIN_SRC) + message(DEBUG "adding ${TARGET_NAME}") + + add_test_executable(${TARGET_NAME} ${MAIN_SRC}) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + + foreach(source IN LISTS ARGN) + list(APPEND INSTANCE_SRCS ${source}) + endforeach() + + target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) + + set(COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + + target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) + endfunction(add_smoothquant_test TARGET_NAME MAIN_SRC) + + file(GLOB INSTANCE_SRCS instances/*.cpp) + add_smoothquant_test(test_ck_tile_smoothquant_fp16 smoothquant_fp16.cpp ${INSTANCE_SRCS}) + add_smoothquant_test(test_ck_tile_smoothquant_bf16 smoothquant_bf16.cpp ${INSTANCE_SRCS}) + +else() + message(DEBUG "Skipping ck_tile smoothquant tests for current target") +endif() diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n1024_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n1024_instance.cpp new file mode 100644 index 0000000000..8e64d933f5 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n1024_instance.cpp @@ -0,0 +1,21 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +template float smoothquant_>(const S&, A); +#endif + +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n1536_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..0b8c3738b1 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n1536_instance.cpp @@ -0,0 +1,12 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n2048_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..1c805c540a --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n2048_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n256_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n256_instance.cpp new file mode 100644 index 0000000000..0d6707d02c --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n256_instance.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n3072_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..abeba019fb --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n3072_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..be192b3122 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..5d7abd3635 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n512_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n512_instance.cpp new file mode 100644 index 0000000000..faccdd9718 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n512_instance.cpp @@ -0,0 +1,12 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..8ec7432168 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_bf16_n768_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n768_instance.cpp new file mode 100644 index 0000000000..ae7b6055b0 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_bf16_n768_instance.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n1024_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..dfe3e9cc9c --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n1024_instance.cpp @@ -0,0 +1,21 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +template float smoothquant_>(const S&, A); +#endif + +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n1536_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..a84c3ce0ef --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n1536_instance.cpp @@ -0,0 +1,12 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n2048_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..c38fc38438 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n2048_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n256_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n256_instance.cpp new file mode 100644 index 0000000000..a2f8588511 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n256_instance.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n3072_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..99257bc322 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n3072_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..dec70cefb2 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..b85e864523 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp @@ -0,0 +1,13 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n512_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n512_instance.cpp new file mode 100644 index 0000000000..8d64ae043f --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n512_instance.cpp @@ -0,0 +1,12 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..4675a31c25 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fp16_n768_instance.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n768_instance.cpp new file mode 100644 index 0000000000..f0f71fa717 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fp16_n768_instance.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/test/ck_tile/smoothquant/instances/smoothquant_fwd_api.cpp b/test/ck_tile/smoothquant/instances/smoothquant_fwd_api.cpp new file mode 100644 index 0000000000..4b7ef5a38d --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_fwd_api.cpp @@ -0,0 +1,143 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "smoothquant.hpp" + +template +using trait_ = smoothquant_traits_; + +template +float smoothquant_dispatch(smoothquant_traits /*t*/, + smoothquant_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd 2p + if(a.n <= 64) { + r = smoothquant_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + return r; + // clang-format on +} + +float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + return smoothquant_dispatch(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return smoothquant_dispatch(t, a, s); + } + else + throw std::runtime_error("Without supported instances!"); +} diff --git a/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp new file mode 100644 index 0000000000..19310beb94 --- /dev/null +++ b/test/ck_tile/smoothquant/instances/smoothquant_instance_common.hpp @@ -0,0 +1,61 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include +#include "smoothquant.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = smoothquant_args; + +template +using trait_ = smoothquant_traits_; + +template +float smoothquant_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = ck_tile::SmoothquantPipelineProblem< + typename SmoothquantTypeConfig::XDataType, + typename SmoothquantTypeConfig::SmoothScaleDataType, + typename SmoothquantTypeConfig::ComputeDataType, + typename SmoothquantTypeConfig::YScaleDataType, + typename SmoothquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kTwoPass>; + + using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass; + using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::Smoothquant; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/test/ck_tile/smoothquant/smoothquant.hpp b/test/ck_tile/smoothquant/smoothquant.hpp new file mode 100644 index 0000000000..ce9ab25448 --- /dev/null +++ b/test/ck_tile/smoothquant/smoothquant.hpp @@ -0,0 +1,114 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/smoothquant.hpp" +#include + +template +struct SmoothquantTypeConfig; + +template <> +struct SmoothquantTypeConfig +{ + using XDataType = ck_tile::half_t; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +template <> +struct SmoothquantTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using SmoothScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +// runtime args +struct smoothquant_args : public ck_tile::SmoothquantHostArgs +{ +}; + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct smoothquant_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps * (ck_tile::get_warp_size() / ThreadPerBlock_N_); + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +template +float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a); + +// This is the public API, will be generated by script +struct smoothquant_traits +{ + std::string data_type; +}; + +float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&); diff --git a/test/ck_tile/smoothquant/smoothquant.inc b/test/ck_tile/smoothquant/smoothquant.inc new file mode 100644 index 0000000000..afda7de4eb --- /dev/null +++ b/test/ck_tile/smoothquant/smoothquant.inc @@ -0,0 +1,274 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "ck_tile/host.hpp" +#include "smoothquant.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[], int index = 0) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("x_stride", "-1", "input stride per row, if -1 then equal to n") + .insert("y_stride", "-1", "output stride per row, if -1 then equal to n") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv, index); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); + if(x_stride < 0) + x_stride = n; + ck_tile::index_t y_stride = arg_parser.get_int("y_stride"); + if(y_stride < 0) + y_stride = n; + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(x_stride >= n); + + using TypeConfig = SmoothquantTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using SmoothScaleDataType = typename TypeConfig::SmoothScaleDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({m, n}, {x_stride, 1}); + ck_tile::HostTensor smscale_host({n}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {y_stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {y_stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(smscale_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem smscale_buf(smscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + smscale_buf.ToDevice(smscale_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", y_stride:" << y_stride + << std::flush; + + smoothquant_traits traits{data_type}; + + smoothquant_args args{x_buf.GetDeviceBuffer(), + smscale_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + m, + n, + x_stride, + y_stride}; + + float ave_time = smoothquant( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(SmoothScaleDataType) * n + + sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + ck_tile::HostTensor y_host({m, n}, {y_stride, 1}); + // smooth outlier + { + auto f = [&](auto n_) { + auto v_smscale = ck_tile::type_convert(smscale_host(n_)); + + for(int m_ = 0; m_ < m; ++m_) + { + auto v_x = ck_tile::type_convert(x_host(m_, n_)); + y_host(m_, n_) = v_x * v_smscale; + } + }; + + ck_tile::make_ParallelTensorFunctor(f, smscale_host.get_element_space_size())( + std::thread::hardware_concurrency()); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(y_stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * y_stride, + qy_host_dev.begin() + i_r * y_stride + + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * y_stride, + qy_host_ref.begin() + i_r * y_stride + + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +std::vector> create_test_cases(const std::string prec) +{ + return {{"-prec=" + prec, "-m=99", "-n=13", "-x_stride=-1"}, + {"-prec=" + prec, "-m=17", "-n=16", "-x_stride=-1"}, + {"-prec=" + prec, "-m=1", "-n=100", "-x_stride=-1"}, + {"-prec=" + prec, "-m=4", "-n=128", "-x_stride=-1"}, + {"-prec=" + prec, "-m=80", "-n=127", "-x_stride=-1"}, + {"-prec=" + prec, "-m=22", "-n=255", "-x_stride=256"}, + {"-prec=" + prec, "-m=7", "-n=599", "-x_stride=-1"}, + {"-prec=" + prec, "-m=19", "-n=512", "-x_stride=-1"}, + {"-prec=" + prec, "-m=33", "-n=313", "-x_stride=1000"}, + {"-prec=" + prec, "-m=11", "-n=510", "-x_stride=-1"}, + {"-prec=" + prec, "-m=171", "-n=676", "-x_stride=818"}, + {"-prec=" + prec, "-m=91", "-n=636", "-x_stride=-1"}, + {"-prec=" + prec, "-m=12", "-n=768", "-x_stride=800"}, + {"-prec=" + prec, "-m=100", "-n=766", "-x_stride=812"}, + {"-prec=" + prec, "-m=31", "-n=1024", "-x_stride=-1"}, + {"-prec=" + prec, "-m=64", "-n=1000", "-x_stride=1004"}, + {"-prec=" + prec, "-m=8", "-n=1501", "-x_stride=-1"}, + {"-prec=" + prec, "-m=3", "-n=1826", "-x_stride=-1"}, + {"-prec=" + prec, "-m=5", "-n=2040", "-x_stride=-1"}, + {"-prec=" + prec, "-m=7", "-n=2734", "-x_stride=-1"}, + {"-prec=" + prec, "-m=1", "-n=3182", "-x_stride=-1"}, + {"-prec=" + prec, "-m=9", "-n=4096", "-x_stride=-1"}, + {"-prec=" + prec, "-m=3", "-n=8192", "-x_stride=-1"}, + {"-prec=" + prec, "-m=1", "-n=10547", "-x_stride=-1"}, + {"-prec=" + prec, "-m=3", "-n=17134", "-x_stride=-1"}}; +} + +template +bool run_test_case(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return false; + + return run(arg_parser); +} + +template +bool run_test_cases(std::vector>& test_cases) +{ + bool valid = true; + constexpr int num_args = 4; + + char* argv[num_args]; + + for(std::size_t test_idx = 0; test_idx < test_cases.size(); ++test_idx) + { + assert(test_cases[test_idx].size() == num_args && + "invalid number of arguments in test case"); + for(std::size_t idx = 0; idx < num_args; ++idx) + { + argv[idx] = test_cases[test_idx][idx].data(); + } + valid = valid && run_test_case(num_args, argv); + + if(!valid) + break; + } + + return valid; +} diff --git a/test/ck_tile/smoothquant/smoothquant_bf16.cpp b/test/ck_tile/smoothquant/smoothquant_bf16.cpp new file mode 100644 index 0000000000..4f5a8ac63e --- /dev/null +++ b/test/ck_tile/smoothquant/smoothquant_bf16.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant.inc" + +int main() +{ + std::vector> test_cases = create_test_cases("bf16"); + + return !run_test_cases(test_cases); +} diff --git a/test/ck_tile/smoothquant/smoothquant_fp16.cpp b/test/ck_tile/smoothquant/smoothquant_fp16.cpp new file mode 100644 index 0000000000..7d822b4903 --- /dev/null +++ b/test/ck_tile/smoothquant/smoothquant_fp16.cpp @@ -0,0 +1,11 @@ +// Copyright © Advanced Micro Devices, Inc., or its affiliates. +// SPDX-License-Identifier: MIT + +#include "smoothquant.inc" + +int main() +{ + std::vector> test_cases = create_test_cases("fp16"); + + return !run_test_cases(test_cases); +}