From 4faf3ab5871573fb4a18519c71731069b69ac125 Mon Sep 17 00:00:00 2001 From: rocking Date: Fri, 1 Nov 2024 13:51:56 +0800 Subject: [PATCH] [Ck_tile] smoothquant (#1617) * fix compile error * fix typo of padding * Add smoothquant op * Add smoothquant instance library * refine type * add test script * Re-generate smoothquant.hpp * Always use 'current year' in copyright * use Generic2dBlockShape instead * Add vector = 8 instance back * Find exe path automatically * Simplify the api condition * Remove debugging code * update year * Add blank line between function declaration * explicitly cast return value to dim3 * refine return value * Fix default warmup and repeat value * Add comment * refactor sommthquant cmake * Add README * Fix typo --------- Co-authored-by: Po Yen, Chen [ROCm/composable_kernel commit: fbd654545a2644f99c3e7a493ebcc2169938583b] --- .../02_layernorm2d/script/perf_test.sh | 5 +- .../02_layernorm2d/script/smoke_test.sh | 3 +- .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 2 +- .../instances/rmsnorm2d_fwd_api.cpp | 9 +- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp | 2 +- .../ck_tile/10_rmsnorm2d/script/perf_test.sh | 5 +- .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 3 +- .../add_rmsnorm2d_rdquant_fwd.hpp | 6 +- .../example_add_rmsnorm2d_rdquant_fwd.cpp | 8 +- .../add_rmsnorm2d_rdquant_fwd_api.cpp | 9 +- .../script/perf_test.sh | 5 +- .../script/smoke_test.sh | 3 +- example/ck_tile/12_smoothquant/CMakeLists.txt | 24 ++ example/ck_tile/12_smoothquant/README.md | 21 ++ .../12_smoothquant/example_smoothquant.cpp | 237 ++++++++++++++++++ .../smoothquant_bf16_n1024_instance.cpp | 22 ++ .../smoothquant_bf16_n1536_instance.cpp | 13 + .../smoothquant_bf16_n2048_instance.cpp | 14 ++ .../smoothquant_bf16_n256_instance.cpp | 12 + .../smoothquant_bf16_n3072_instance.cpp | 14 ++ .../smoothquant_bf16_n4096_instance.cpp | 14 ++ .../smoothquant_bf16_n4096_tp_instance.cpp | 14 ++ .../smoothquant_bf16_n512_instance.cpp | 13 + .../smoothquant_bf16_n64_n128_instance.cpp | 
12 + .../smoothquant_bf16_n768_instance.cpp | 12 + .../smoothquant_fp16_n1024_instance.cpp | 22 ++ .../smoothquant_fp16_n1536_instance.cpp | 13 + .../smoothquant_fp16_n2048_instance.cpp | 14 ++ .../smoothquant_fp16_n256_instance.cpp | 12 + .../smoothquant_fp16_n3072_instance.cpp | 14 ++ .../smoothquant_fp16_n4096_instance.cpp | 14 ++ .../smoothquant_fp16_n4096_tp_instance.cpp | 14 ++ .../smoothquant_fp16_n512_instance.cpp | 13 + .../smoothquant_fp16_n64_n128_instance.cpp | 12 + .../smoothquant_fp16_n768_instance.cpp | 12 + .../instances/smoothquant_fwd_api.cpp | 143 +++++++++++ .../instances/smoothquant_instance_common.hpp | 62 +++++ .../12_smoothquant/script/perf_test.sh | 37 +++ .../12_smoothquant/script/smoke_test.sh | 30 +++ .../ck_tile/12_smoothquant/smoothquant.cpp | 218 ++++++++++++++++ .../ck_tile/12_smoothquant/smoothquant.hpp | 114 +++++++++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 1 - .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 17 +- .../add_rmsnorm2d_rdquant_fwd_shape.hpp | 78 ------ ...2d_rdquant_fwd_pipeline_default_policy.hpp | 1 + .../kernel/layernorm2d_fwd_kernel.hpp | 4 +- ...ayernorm2d_fwd_pipeline_default_policy.hpp | 1 + .../layernorm2d_fwd_pipeline_problem.hpp | 2 +- .../pipeline/layernorm2d_fwd_traits.hpp | 2 +- .../ops/reduce/block/block_reduce2d.hpp | 3 +- include/ck_tile/ops/rmsnorm2d.hpp | 1 - .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 12 +- .../rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp | 78 ------ .../rmsnorm2d_fwd_pipeline_default_policy.hpp | 1 + include/ck_tile/ops/smoothquant.hpp | 12 + .../smoothquant/kernel/smoothquant_kernel.hpp | 176 +++++++++++++ .../smoothquant_pipeline_default_policy.hpp | 95 +++++++ .../smoothquant_pipeline_one_pass.hpp | 94 +++++++ .../pipeline/smoothquant_pipeline_problem.hpp | 35 +++ .../smoothquant_pipeline_two_pass.hpp | 132 ++++++++++ include/ck_tile/remod.py | 5 +- 62 files changed, 1758 insertions(+), 219 deletions(-) create mode 100644 
example/ck_tile/12_smoothquant/CMakeLists.txt create mode 100644 example/ck_tile/12_smoothquant/README.md create mode 100644 example/ck_tile/12_smoothquant/example_smoothquant.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1536_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n2048_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n256_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n3072_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n512_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n768_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1536_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n2048_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n256_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n3072_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n512_instance.cpp create mode 100644 
example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp create mode 100644 example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp create mode 100755 example/ck_tile/12_smoothquant/script/perf_test.sh create mode 100755 example/ck_tile/12_smoothquant/script/smoke_test.sh create mode 100644 example/ck_tile/12_smoothquant/smoothquant.cpp create mode 100644 example/ck_tile/12_smoothquant/smoothquant.hpp delete mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp delete mode 100644 include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp create mode 100644 include/ck_tile/ops/smoothquant.hpp create mode 100644 include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp create mode 100644 include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp create mode 100644 include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp diff --git a/example/ck_tile/02_layernorm2d/script/perf_test.sh b/example/ck_tile/02_layernorm2d/script/perf_test.sh index a34624536c..5a34e19280 100755 --- a/example/ck_tile/02_layernorm2d/script/perf_test.sh +++ b/example/ck_tile/02_layernorm2d/script/perf_test.sh @@ -1,6 +1,5 @@ - -# run from top of ck folder -EXE=build/bin/tile_example_layernorm2d_fwd +#!/bin/sh +EXE="$(find . 
-name tile_example_layernorm2d_fwd -type f | head -n 1)" $EXE -m=1 -n=1 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 $EXE -m=700 -n=80 -e=1e-12 -v=1 -prec_i=bf16 -repeat=1000 diff --git a/example/ck_tile/02_layernorm2d/script/smoke_test.sh b/example/ck_tile/02_layernorm2d/script/smoke_test.sh index d56406b6f2..b7fd354bb8 100755 --- a/example/ck_tile/02_layernorm2d/script/smoke_test.sh +++ b/example/ck_tile/02_layernorm2d/script/smoke_test.sh @@ -1,6 +1,5 @@ #!/bin/sh -# call from top of CK folder -EXE=./build/bin/tile_example_layernorm2d_fwd +EXE="$(find . -name tile_example_layernorm2d_fwd -type f | head -n 1)" for fquant in "" "-fquant=1 -prec_o=int8"; do for pr_i in "fp16" "bf16" ; do diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp index bb2c949015..34df7b74fa 100644 --- a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -69,7 +69,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using WarpTile = ck_tile::sequence<1, 64>; using Vector = ck_tile::sequence<1, 1>; - using Shape = ck_tile::Rmsnorm2dShape; + using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem>(s, a); } return r; -#else - return rmsnorm2d_fwd_>(s, a); -#endif // clang-format on } float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s) { - float r = -1; if(t.data_type.compare("fp16") == 0) { return rmsnorm2d_fwd_b16_(t, a, s); @@ -146,8 +141,6 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile: { return rmsnorm2d_fwd_b16_(t, a, s); } - if(r < 0) + else throw std::runtime_error("Without supported instances!"); - - return r; } diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp index 756ecb2c4b..b4d429d46f 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp +++ 
b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -97,7 +97,7 @@ struct rmsnorm2d_fwd_traits_ using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - using Shape = ck_tile::Rmsnorm2dShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveInvRms = kSaveInvRms_; diff --git a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh index f3cfcc4b89..7b9d0820fd 100755 --- a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh @@ -1,6 +1,5 @@ - -# run from top of ck folder -EXE=build/bin/tile_rmsnorm2d_fwd +#!/bin/sh +EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" $EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 6ec5e846ce..758d6de546 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -1,6 +1,5 @@ #!/bin/sh -# call from top of CK folder -EXE=./build/bin/tile_rmsnorm2d_fwd +EXE="$(find . 
-name tile_rmsnorm2d_fwd -type f | head -n 1)" for pr_i in "fp16" "bf16" ; do $EXE -prec=$pr_i -m=99 -n=13 diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp index bf70d9d23f..443b9b1024 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -18,7 +18,7 @@ struct AddRmsnormRdquantTypeConfig using BDataType = ck_tile::half_t; using GammaDataType = ck_tile::half_t; using XDataType = ck_tile::half_t; - using YScaleDataType = ck_tile::half_t; + using YScaleDataType = float; using QYDataType = ck_tile::int8_t; using ComputeDataType = float; }; @@ -30,7 +30,7 @@ struct AddRmsnormRdquantTypeConfig using BDataType = ck_tile::bf16_t; using GammaDataType = ck_tile::bf16_t; using XDataType = ck_tile::bf16_t; - using YScaleDataType = ck_tile::bf16_t; + using YScaleDataType = float; using QYDataType = ck_tile::int8_t; using ComputeDataType = float; }; @@ -101,7 +101,7 @@ struct add_rmsnorm2d_rdquant_fwd_traits_ using WarpTile = ck_tile::sequence; using Vector = ck_tile::sequence<1, Vector_N_>; - using Shape = ck_tile::AddRmsnorm2dRdquantShape; + using Shape = ck_tile::Generic2dBlockShape; static constexpr bool kPadN = kPadN_; static constexpr bool kSaveX = kSaveX_; diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp index 40fabf7f55..ada4c6f2da 100644 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -66,7 +66,7 @@ bool run(const ck_tile::ArgParser& arg_parser) using BDataType = DataType; using GammaDataType = DataType; using XDataType = DataType; - using YScaleDataType = DataType; + using YScaleDataType = float; using 
QYDataType = ck_tile::int8_t; using ComputeDataType = float; @@ -99,12 +99,12 @@ bool run(const ck_tile::ArgParser& arg_parser) constexpr bool kThreePass = true; - using BlockWarps = ck_tile::sequence<2, 2>; - using BlockTile = ck_tile::sequence<2, 128>; + using BlockWarps = ck_tile::sequence<4, 1>; + using BlockTile = ck_tile::sequence<4, 128>; using WarpTile = ck_tile::sequence<1, 64>; using Vector = ck_tile::sequence<1, 1>; - using Shape = ck_tile::AddRmsnorm2dRdquantShape; + using Shape = ck_tile::Generic2dBlockShape; using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem>(s, a); } return r; -#else - return add_rmsnorm2d_rdquant_fwd_>(s, a); -#endif // clang-format on } @@ -139,7 +135,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, const ck_tile::stream_config& s) { - float r = -1; // Only support instance of save_x == true for now assert(t.save_x); if(t.data_type.compare("fp16") == 0) @@ -150,8 +145,6 @@ float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, { return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); } - if(r < 0) + else throw std::runtime_error("Without supported instances!"); - - return r; } diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh index 11fd364886..d02b0bab33 100755 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh @@ -1,6 +1,5 @@ - -# run from top of ck folder -EXE=build/bin/tile_add_rmsnorm2d_rdquant_fwd +#!/bin/sh +EXE="$(find . 
-name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)" $EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 $EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh index 4a02cdcb65..b60f5fcf20 100755 --- a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh @@ -1,6 +1,5 @@ #!/bin/sh -# call from top of CK folder -EXE=./build/bin/tile_add_rmsnorm2d_rdquant_fwd +EXE="$(find . -name tile_add_rmsnorm2d_rdquant_fwd -type f | head -n 1)" for pr_i in "fp16" "bf16" ; do $EXE -prec=$pr_i -m=99 -n=13 diff --git a/example/ck_tile/12_smoothquant/CMakeLists.txt b/example/ck_tile/12_smoothquant/CMakeLists.txt new file mode 100644 index 0000000000..09a56c6dab --- /dev/null +++ b/example/ck_tile/12_smoothquant/CMakeLists.txt @@ -0,0 +1,24 @@ +function (add_smoothquant_example TARGET_NAME MAIN_SRC) + message("adding ${TARGET_NAME}") + # not using add_example_executable() to add target, since we don't want this to have + # to be included in "make all/install/check" + add_executable(${TARGET_NAME} EXCLUDE_FROM_ALL ${MAIN_SRC}) + target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) + + foreach(source IN LISTS ARGN) + list(APPEND INSTANCE_SRCS ${source}) + endforeach() + + target_sources(${TARGET_NAME} PRIVATE ${INSTANCE_SRCS}) + + set(COMPILE_OPTIONS) + # NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations + list(APPEND COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + + target_compile_options(${TARGET_NAME} PRIVATE ${COMPILE_OPTIONS}) +endfunction(add_smoothquant_example TARGET_NAME MAIN_SRC) + +file(GLOB INSTANCE_SRCS instances/*.cpp) + +add_smoothquant_example(tile_smoothquant smoothquant.cpp ${INSTANCE_SRCS}) +add_smoothquant_example(tile_example_smoothquant 
example_smoothquant.cpp) diff --git a/example/ck_tile/12_smoothquant/README.md b/example/ck_tile/12_smoothquant/README.md new file mode 100644 index 0000000000..d6b815f8cf --- /dev/null +++ b/example/ck_tile/12_smoothquant/README.md @@ -0,0 +1,21 @@ +# smoothquant + +This folder contains an example of smoothquant using the ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace <arch> with gfx90a, gfx942... +make tile_smoothquant -j +``` +This will result in an executable `build/bin/tile_smoothquant` + +## cmdline +``` +args: + -m m dimension (default:3328) + -n n dimension (default:4096) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/12_smoothquant/example_smoothquant.cpp b/example/ck_tile/12_smoothquant/example_smoothquant.cpp new file mode 100644 index 0000000000..3a26eb6a77 --- /dev/null +++ b/example/ck_tile/12_smoothquant/example_smoothquant.cpp @@ -0,0 +1,237 @@ +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/smoothquant.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", 
"fp16", "precision") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using XDataType = DataType; + using XScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor xscale_host({n}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + xscale_buf.ToDevice(xscale_host.data()); + + constexpr bool kTwoPass = true; + + using BlockWarps = ck_tile::sequence<2, 2>; + using BlockTile = ck_tile::sequence<2, 128>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; + + using Shape = ck_tile::Generic2dBlockShape; + using Problem = ck_tile::SmoothquantPipelineProblem; + + using 
OnePassPipeline = ck_tile::SmoothquantPipelineOnePass; + using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::Smoothquant; + + ck_tile::SmoothquantHostArgs args{x_buf.GetDeviceBuffer(), + xscale_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + m, + n, + stride}; + + auto kargs = Kernel::MakeKargs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + auto s = ck_tile::stream_config{nullptr, true, 1, warmup, repeat}; + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + ck_tile::HostTensor y_host({m, n}, {stride, 1}); + // smooth outlier + { + auto f = [&](auto n_) { + auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + + for(int m_ = 0; m_ < m; ++m_) + { + auto v_x = ck_tile::type_convert(x_host(m_, n_)); + y_host(m_, n_) = v_x * v_xscale; + } + }; + + ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + std::thread::hardware_concurrency()); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto 
[rtol, atol] = get_elimit(); + + if(stride == n) + { + pass &= ck_tile::check_err(qy_host_dev, // &=: do not discard an earlier yscale mismatch + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + /*else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + }*/ + + return -3; +} diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp new file mode 100644 index 0000000000..b25361da2f --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +template float smoothquant_>(const S&, A); +#endif + +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1536_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..0a332fe410 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n2048_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..bdf5804e43 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n256_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n256_instance.cpp new file mode 100644 index 0000000000..774c977f2e --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n3072_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..c571ef443e --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..80e4b3a296 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..7f776a6e46 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n512_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n512_instance.cpp new file mode 100644 index 0000000000..12bc90b669 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..1cee186063 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n768_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n768_instance.cpp new file mode 100644 index 0000000000..aca7f7eb4e --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..be5fecaca1 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +#if 0 +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +template float smoothquant_>(const S&, A); +#endif + +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1536_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..59fe148750 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n2048_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..a3710a6ab4 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n256_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n256_instance.cpp new file mode 100644 index 0000000000..2b1bca7aa4 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n3072_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..205ba130e4 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..96503ac913 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..36e5e0bb14 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n512_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n512_instance.cpp new file mode 100644 index 0000000000..f09932e295 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..023cd0be6e --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp new file mode 100644 index 0000000000..5dcf560c74 --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "smoothquant_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd 2p +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +template float smoothquant_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp b/example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp new file mode 100644 index 0000000000..962755f6ef --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_fwd_api.cpp @@ -0,0 +1,143 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "smoothquant.hpp" + +template +using trait_ = smoothquant_traits_; + +template +float smoothquant_dispatch(smoothquant_traits /*t*/, + smoothquant_args a, + const ck_tile::stream_config& s) +{ + float r = -1; + // clang-format off + // rm rn tm tn vn pd 2p + if(a.n <= 64) { + r = smoothquant_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, 
a); + else + r = smoothquant_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = smoothquant_>(s, a); + else if (a.n % 4 == 0) + r = smoothquant_>(s, a); + else if (a.n % 2 == 0) + r = smoothquant_>(s, a); + else + r = smoothquant_>(s, a); + } + return r; + // clang-format on +} + +float smoothquant(smoothquant_traits t, smoothquant_args a, const ck_tile::stream_config& s) +{ + if(t.data_type.compare("fp16") == 0) + { + return smoothquant_dispatch(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return smoothquant_dispatch(t, a, s); + } + else + throw std::runtime_error("Without supported instances!"); +} diff --git a/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp new file mode 100644 index 0000000000..cdf93f6fcf --- /dev/null +++ b/example/ck_tile/12_smoothquant/instances/smoothquant_instance_common.hpp @@ -0,0 +1,62 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include +#include "smoothquant.hpp" +#include + +using S = ck_tile::stream_config; +using A = smoothquant_args; + +template +using trait_ = smoothquant_traits_; + +template +float smoothquant_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = ck_tile::SmoothquantPipelineProblem< + typename SmoothquantTypeConfig::XDataType, + typename SmoothquantTypeConfig::XScaleDataType, + typename SmoothquantTypeConfig::ComputeDataType, + typename SmoothquantTypeConfig::YScaleDataType, + typename SmoothquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kTwoPass>; + + using OnePassPipeline = ck_tile::SmoothquantPipelineOnePass; + using TwoPassPipeline = ck_tile::SmoothquantPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::Smoothquant; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/12_smoothquant/script/perf_test.sh b/example/ck_tile/12_smoothquant/script/perf_test.sh new file mode 100755 index 0000000000..741eb32ec1 --- /dev/null +++ b/example/ck_tile/12_smoothquant/script/perf_test.sh @@ -0,0 +1,37 @@ +#!/bin/sh +EXE="$(find .
-name tile_smoothquant -type f | head -n 1)" + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff 
--git a/example/ck_tile/12_smoothquant/script/smoke_test.sh b/example/ck_tile/12_smoothquant/script/smoke_test.sh new file mode 100755 index 0000000000..d08e063966 --- /dev/null +++ b/example/ck_tile/12_smoothquant/script/smoke_test.sh @@ -0,0 +1,30 @@ +#!/bin/sh +EXE="$(find . -name tile_smoothquant -type f | head -n 1)" + +for pr_i in "fp16" "bf16" ; do +$EXE -prec=$pr_i -m=99 -n=13 +$EXE -prec=$pr_i -m=17 -n=16 +$EXE -prec=$pr_i -m=1 -n=100 +$EXE -prec=$pr_i -m=4 -n=128 +$EXE -prec=$pr_i -m=80 -n=127 +$EXE -prec=$pr_i -m=22 -n=255 -stride=256 +$EXE -prec=$pr_i -m=7 -n=599 +$EXE -prec=$pr_i -m=19 -n=512 +$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 +$EXE -prec=$pr_i -m=11 -n=510 +$EXE -prec=$pr_i -m=171 -n=676 -stride=818 +$EXE -prec=$pr_i -m=91 -n=636 +$EXE -prec=$pr_i -m=12 -n=768 -stride=800 +$EXE -prec=$pr_i -m=100 -n=766 -stride=812 +$EXE -prec=$pr_i -m=31 -n=1024 +$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 +$EXE -prec=$pr_i -m=8 -n=1501 +$EXE -prec=$pr_i -m=3 -n=1826 +$EXE -prec=$pr_i -m=5 -n=2040 +$EXE -prec=$pr_i -m=7 -n=2734 +$EXE -prec=$pr_i -m=1 -n=3182 +$EXE -prec=$pr_i -m=9 -n=4096 +$EXE -prec=$pr_i -m=3 -n=8192 +$EXE -prec=$pr_i -m=1 -n=10547 +$EXE -prec=$pr_i -m=3 -n=17134 +done diff --git a/example/ck_tile/12_smoothquant/smoothquant.cpp b/example/ck_tile/12_smoothquant/smoothquant.cpp new file mode 100644 index 0000000000..ed01d654fd --- /dev/null +++ b/example/ck_tile/12_smoothquant/smoothquant.cpp @@ -0,0 +1,218 @@ +#include "ck_tile/host.hpp" +#include "smoothquant.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-5; + double atol = 1e-5; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); 
+} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = SmoothquantTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using XScaleDataType = typename TypeConfig::XScaleDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor xscale_host({n}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{1e-3, .5f}(xscale_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem xscale_buf(xscale_host.get_element_space_size_in_bytes()); + 
ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + xscale_buf.ToDevice(xscale_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + smoothquant_traits traits{data_type}; + + smoothquant_args args{x_buf.GetDeviceBuffer(), + xscale_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + m, + n, + stride}; + + float ave_time = smoothquant( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + std::size_t num_byte = sizeof(XDataType) * m * n + sizeof(XScaleDataType) * n + + sizeof(YScaleDataType) * m + sizeof(QYDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + ck_tile::HostTensor y_host({m, n}, {stride, 1}); + // smooth outlier + { + auto f = [&](auto n_) { + auto v_xscale = ck_tile::type_convert(xscale_host(n_)); + + for(int m_ = 0; m_ < m; ++m_) + { + auto v_x = ck_tile::type_convert(x_host(m_, n_)); + y_host(m_, n_) = v_x * v_xscale; + } + }; + + ck_tile::make_ParallelTensorFunctor(f, xscale_host.get_element_space_size())( + std::thread::hardware_concurrency()); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( + y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + 
yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/12_smoothquant/smoothquant.hpp b/example/ck_tile/12_smoothquant/smoothquant.hpp new file mode 100644 index 0000000000..26a598db55 --- /dev/null +++ b/example/ck_tile/12_smoothquant/smoothquant.hpp @@ -0,0 +1,114 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/smoothquant.hpp" +#include + +template +struct SmoothquantTypeConfig; + +template <> +struct SmoothquantTypeConfig +{ + using XDataType = ck_tile::half_t; + using XScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +template <> +struct SmoothquantTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using XScaleDataType = float; + using YScaleDataType = float; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +// runtime args +struct smoothquant_args : public ck_tile::SmoothquantHostArgs +{ +}; + +// this is used to pattern-match internal kernel implementation, not to instantiate kernel +template +struct smoothquant_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ *
ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Generic2dBlockShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +template +float smoothquant_(const ck_tile::stream_config& s, smoothquant_args a); + +// This is the public API, will be generated by script +struct smoothquant_traits +{ + std::string data_type; +}; + +float smoothquant(smoothquant_traits, smoothquant_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index e404e5019e..9dd9a6ca3c 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -11,3 +11,4 @@ add_subdirectory(06_permute) add_subdirectory(09_topk_softmax) add_subdirectory(10_rmsnorm2d) add_subdirectory(11_add_rmsnorm2d_rdquant) +add_subdirectory(12_smoothquant) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp index fb8d7221b8..d06d8529ac 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp" -#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp" #include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" #include 
"ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp index 4a0e290352..f06910db3d 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -9,15 +9,16 @@ namespace ck_tile { // host side args +// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale) struct AddRmsnorm2dRdquantFwdHostArgs { - const void* p_a; - const void* p_b; - const void* p_gamma; + const void* p_a; // [m, n], input, fp16/bf16 + const void* p_b; // [m, n], input, fp16/bf16 + const void* p_gamma; // [1, n], gamma, prec same as input - void* p_x; - void* p_yscale; - void* p_qy; + void* p_x; // [m, n], output, p_a + p_b, fp16/bf16 + void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of result of rmsnorm2d(x) + void* p_qy; // [m, n], output, result of quant tensor of rmsnorm2d(x) int8 float epsilon; @@ -90,7 +91,7 @@ struct AddRmsnorm2dRdquantFwd CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) { - return integer_divide_ceil(hargs.m, Block_M); + return dim3(integer_divide_ceil(hargs.m, Block_M)); } CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } @@ -170,7 +171,7 @@ struct AddRmsnorm2dRdquantFwd number<1>{}); const auto tmp2_ = - pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); return make_tile_window(tmp2_, make_tuple(number{}), {0}); }(); diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp deleted file mode 100644 index 4bc7db434e..0000000000 ---
a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp +++ /dev/null @@ -1,78 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -/* -// clang-format off - -4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector - - Block_N (Warp_N * WarpPerBlock_N * Repeat_N ) - +<----------------------< Repeat_N(2)>--------------------->+ - | | - +<-- -->+ - Warp_N - +--------------+--------------+--------------+--------------+----+----------------+ - Warp_M | wrap_0 | wrap_1 | | ^ ^ - +--------------+--------------+ | | - | wrap_2 | wrap_3 | | v - +--------------+--------------+--------------+--------------+----+ Block_M - | | | - + + | - | | | v - +--------------+--------------+--------------+--------------+ + - - each Warp-tile (e.g 16 thrd per row) - - Vector_N (contiguous pixels each thrd holds along N, or vector size) - +-----------+-----------+-----------+-----------+-----------+ - | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M - +-----------+-----------+-----------+-----------+-----------+ - | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ... 
- +-----------+-----------+-----------+-----------+-----------+ -// clang-format on -*/ -template - typename WarpPerBlock_, // num warps along seq - typename WarpTile_, // warp size, seq - typename Vector_, // contiguous pixels(vector size) along seq - index_t BlockSize_ = - warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> -struct AddRmsnorm2dRdquantShape -{ - // block size - static constexpr index_t Block_M = BlockTile_::at(number<0>{}); - static constexpr index_t Block_N = BlockTile_::at(number<1>{}); - - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); - static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); - - // warp size - static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); - static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); - - static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); - static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); - // repeat of each thread along seq - static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); - static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); - - // vector size along seq - static constexpr index_t Vector_M = Vector_::at(number<0>{}); - static constexpr index_t Vector_N = Vector_::at(number<1>{}); - - static_assert(Warp_M % Vector_M == 0); - static_assert(Warp_N % Vector_N == 0); - // num of threads along seq, within each warp - static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; - - static constexpr index_t BlockSize = BlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp index 73ba633b15..0b9bae4e9e 100644 --- 
a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp @@ -26,6 +26,7 @@ struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); } + template CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution() { diff --git a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp index 9a2e06d05f..f5a214ba57 100644 --- a/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp @@ -117,7 +117,7 @@ struct Layernorm2dFwd CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) { - return (hargs.m + Block_M - 1) / Block_M; + return dim3(integer_divide_ceil(hargs.m, Block_M)); } CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } @@ -165,7 +165,7 @@ struct Layernorm2dFwd return base_str; }(); - return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" + + return _SS_("layernorm2d_fwd_") + _SS_(prec_str) + "_" + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + _SS_(Pipeline::name) + surfix; diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp index 6661cddf43..02fd5f7b93 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp @@ -26,6 +26,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); } + 
template CK_TILE_DEVICE static constexpr auto MakeGammaBetaBlockTileDistribution() { diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp index 7ec830add1..17ff80f471 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp index fb327f74a3..ed9e18be30 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp @@ -1,5 +1,5 @@ // SPDX-License-Identifier: MIT -// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. #pragma once diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index 3c68147112..d6ca98e7b4 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -29,7 +29,8 @@ struct BlockReduce2d sweep_tile( [&](auto... 
idx_) { constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]); - y_tensor(idx_0) = reduce_func(y_tensor(idx_0), x_tensor[idx_]...); + y_tensor(idx_0) = reduce_func( + y_tensor(idx_0), ck_tile::type_convert(x_tensor[idx_])...); }, ReducePacksPerXDim{}); #if 0 diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp index f0a6cf9603..8d075dc5fa 100644 --- a/include/ck_tile/ops/rmsnorm2d.hpp +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -4,7 +4,6 @@ #pragma once #include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" -#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" #include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp index 99084a25e4..fd89cc36c7 100644 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -11,11 +11,11 @@ namespace ck_tile { // host side args struct Rmsnorm2dFwdHostArgs { - const void* p_x; - const void* p_gamma; + const void* p_x; // [m ,n], input, fp16/bf16 + const void* p_gamma; // [1, n], gamma, prec same as input - void* p_y; - void* p_invRms; + void* p_y; // [m, n], output, fp16/bf16 + void* p_invRms; // [m, 1], output inv-rms, prec same as input, nullptr if not used float epsilon; @@ -83,7 +83,7 @@ struct Rmsnorm2dFwd CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) { - return (hargs.m + Block_M - 1) / Block_M; + return dim3(integer_divide_ceil(hargs.m, Block_M)); } CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } @@ -149,7 +149,7 @@ struct Rmsnorm2dFwd number<1>{}); const auto tmp2_ = - pad_tensor_view(tmp_, 
make_tuple(number{}), sequence{}); + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); return make_tile_window(tmp2_, make_tuple(number{}), {0}); }(); diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp deleted file mode 100644 index fc4b9f470c..0000000000 --- a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp +++ /dev/null @@ -1,78 +0,0 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. - -#pragma once - -#include "ck_tile/core.hpp" - -namespace ck_tile { -/* -// clang-format off - -4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector - - Block_N (Warp_N * WarpPerBlock_N * Repeat_N ) - +<----------------------< Repeat_N(2)>--------------------->+ - | | - +<-- -->+ - Warp_N - +--------------+--------------+--------------+--------------+----+----------------+ - Warp_M | wrap_0 | wrap_1 | | ^ ^ - +--------------+--------------+ | | - | wrap_2 | wrap_3 | | v - +--------------+--------------+--------------+--------------+----+ Block_M - | | | - + + | - | | | v - +--------------+--------------+--------------+--------------+ + - - each Warp-tile (e.g 16 thrd per row) - - Vector_N (contiguous pixels each thrd holds along N, or vector size) - +-----------+-----------+-----------+-----------+-----------+ - | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M - +-----------+-----------+-----------+-----------+-----------+ - | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ... 
- +-----------+-----------+-----------+-----------+-----------+ -// clang-format on -*/ -template - typename WarpPerBlock_, // num warps along seq - typename WarpTile_, // warp size, seq - typename Vector_, // contiguous pixels(vector size) along seq - index_t BlockSize_ = - warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> -struct Rmsnorm2dShape -{ - // block size - static constexpr index_t Block_M = BlockTile_::at(number<0>{}); - static constexpr index_t Block_N = BlockTile_::at(number<1>{}); - - // num warps along seq, within each block - static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); - static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); - - // warp size - static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); - static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); - - static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); - static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); - // repeat of each thread along seq - static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); - static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); - - // vector size along seq - static constexpr index_t Vector_M = Vector_::at(number<0>{}); - static constexpr index_t Vector_N = Vector_::at(number<1>{}); - - static_assert(Warp_M % Vector_M == 0); - static_assert(Warp_N % Vector_N == 0); - // num of threads along seq, within each warp - static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; - - static constexpr index_t BlockSize = BlockSize_; -}; - -} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp index e4814cf455..b258dcbae1 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp +++ 
b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -26,6 +26,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy sequence<1, 1, 2, 2>, sequence<0, 3, 0, 3>>{}); } + template CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution() { diff --git a/include/ck_tile/ops/smoothquant.hpp b/include/ck_tile/ops/smoothquant.hpp new file mode 100644 index 0000000000..c9e4597657 --- /dev/null +++ b/include/ck_tile/ops/smoothquant.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp" +#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp" +#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp" +#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp" +#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/generic_2d_block_shape.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp new file mode 100644 index 0000000000..6ec3335168 --- /dev/null +++ b/include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp @@ -0,0 +1,176 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +// host side args +struct SmoothquantHostArgs +{ + const void* p_x; // [m ,n], input, fp16/bf16 + const void* p_xscale; // [1, n], input, columnwise scale, fp32 + + void* p_yscale; // [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_xscale) + void* p_qy; // [m, n], output, p_x * p_xscale / p_yscale + + index_t m; + index_t n; + index_t stride; // row_stride +}; + +// TODO: Extract some type to wrapper class +template +struct Smoothquant +{ + using Pipeline = remove_cvref_t; + using Problem = typename Pipeline::Problem; + + using XDataType = remove_cvref_t; + using XScaleDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using QYDataType = remove_cvref_t; + + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kTwoPass = Problem::kTwoPass; + + static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; + static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; + static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + struct Kargs + { + const void* p_x; + const void* p_xscale; + + void* p_yscale; + void* p_qy; + + index_t m; + index_t n; + index_t stride; // row_stride + }; + using Hargs = SmoothquantHostArgs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) + { + return Kargs{ + hargs.p_x, hargs.p_xscale, hargs.p_yscale, hargs.p_qy, hargs.m, hargs.n, hargs.stride}; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) + { + return dim3(integer_divide_ceil(hargs.m, Block_M)); + } + + CK_TILE_HOST 
static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + // in byte + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + using S_ = typename Problem::BlockShape; + auto surfix = [&] () { + std::string n; + if (kPadN) n += "_pn"; + if (kTwoPass) n += "_2p"; + return n; }(); + + #define _SS_ std::string + #define _TS_ std::to_string + return _SS_("smoothquant_fwd_") + _SS_(t2s::name) + "_" + + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + + _SS_(Pipeline::name) + surfix; + #undef _SS_ + #undef _TS_ + // clang-format on + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const auto iM = get_block_id() * Block_M; + + const auto x_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_x), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + const auto xscale_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_xscale), + make_tuple(kargs.n), + make_tuple(1), + number{}, + number<1>{}); + + const auto tmp2_ = + pad_tensor_view(tmp_, 
make_tuple(number{}), sequence{}); + + return make_tile_window(tmp2_, make_tuple(number{}), {0}); + }(); + + auto yscale_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_yscale), + make_tuple(kargs.m), + make_tuple(1), + number<1>{}); + + const auto tmp2_ = + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + + return make_tile_window(tmp2_, make_tuple(number{}), {iM}); + }(); + + auto qy_window = [&]() { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_qy), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + __shared__ char smem[GetSmemSize()]; + + Pipeline{}(x_window, xscale_window, yscale_window, qy_window, kargs.n, smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp new file mode 100644 index 0000000000..ff81e69f0c --- /dev/null +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d.hpp" + +namespace ck_tile { + +struct SmoothquantPipelineDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_DEVICE static constexpr auto MakeXScaleBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2d{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dCrossWarpSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + if constexpr(Problem::kNeedCrossWarpSync) + { + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = + decltype(make_static_distributed_tensor( + MakeXBlockTileDistribution())); + using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile()); + + return GetBlockReduce2dCrossWarpSync().template GetSmemSize(); + } + else + { + return 1; // zero size arrays are an extension + } + } +}; +} // namespace ck_tile diff --git 
a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp new file mode 100644 index 0000000000..d5b3780dea --- /dev/null +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct SmoothquantPipelineOnePass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using XScaleDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using QYDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_op"; // block per row + else + return "wpr_op"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XScaleWindow& xscale_window_, + YScaleWindow& yscale_window, + QYWindow& qy_window, + ck_tile::index_t, + void* smem) const + { + auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + auto xscale_window = make_tile_window( + xscale_window_, Policy::template MakeXScaleBlockTileDistribution()); + + auto reduce_absmax_func = ReduceOp::AbsMax{}; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = 
Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + const auto x = load_tile(x_window); + const auto xscale = load_tile(xscale_window); + auto y = tile_elementwise_in( + [&](const auto& a, const auto& b) { + return type_convert(a) * type_convert(b); + }, + x, + xscale); + + // compute absmax, cross-lane->cross-warp + auto absmax = block_reduce2d( + y, reduce_absmax_func.GetIdentityValue(), reduce_absmax_func); + block_reduce2d_sync(absmax, reduce_max_func); + block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); + + // ex: yscale = absmax / 127 if int8 + auto yscale = tile_elementwise_in( + [&](const auto& v_) { + return v_ / type_convert(numeric::max()); + }, + absmax); + store_tile(yscale_window, cast_tile(yscale)); + + // quantize y to qy + auto qy = make_static_distributed_tensor(y.get_tile_distribution()); + sweep_tile(qy, [&](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + auto qy_ = y[idx] / yscale[i_idx]; + qy(idx) = saturates{}(qy_); + }); + store_tile(qy_window, qy); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp new file mode 100644 index 0000000000..37e09b58cf --- /dev/null +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_problem.hpp @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// Y = X * XScale, QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale) +template +struct SmoothquantPipelineProblem +{ + using XDataType = remove_cvref_t; + using XScaleDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using QYDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp new file mode 100644 index 0000000000..7878ef1d34 --- /dev/null +++ b/include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_two_pass.hpp @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct SmoothquantPipelineTwoPass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using XScaleDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using QYDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockSmoothquantProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_tp"; // block per row + else + return "wpr_tp"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const XScaleWindow& xscale_window_, + YScaleWindow& yscale_window, + QYWindow& qy_window, + ck_tile::index_t row_size, + void* smem) const + { + auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + auto xscale_window = make_tile_window( + xscale_window_, Policy::template MakeXScaleBlockTileDistribution()); + + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + + auto reduce_absmax_func = ReduceOp::AbsMax{}; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + using XTensorType =
decltype(cast_tile(load_tile(x_window))); + auto absmax = block_reduce2d.template MakeYBlockTile(); + set_tile(absmax, reduce_absmax_func.GetIdentityValue()); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + const auto xscale = load_tile(xscale_window); + const auto y = tile_elementwise_in( + [&](const auto& a, const auto& b) { + return type_convert(a) * type_convert(b); + }, + x, + xscale); + + block_reduce2d(y, absmax, reduce_absmax_func); + + move_tile_window(x_window, {0, Block_N}); + move_tile_window(xscale_window, {Block_N}); + } + + // compute absmax, cross-lane->cross-warp + block_reduce2d_sync(absmax, reduce_max_func); + block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); + + // ex: yscale = absmax / 127 if int8 + auto yscale = tile_elementwise_in( + [&](const auto& v_) { + return v_ / type_convert(numeric::max()); + }, + absmax); + store_tile(yscale_window, cast_tile(yscale)); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = + row_size % Block_N == 0 ?
row_size - Block_N : row_size - row_size % Block_N; + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(xscale_window, {-Block_N}); + move_tile_window(qy_window, {0, stride_to_right_most_window}); + + // recompute y and quantize y to qy + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + const auto xscale = load_tile(xscale_window); + const auto y = tile_elementwise_in( + [&](const auto& a, const auto& b) { + return type_convert(a) * type_convert(b); + }, + x, + xscale); + + auto qy = make_static_distributed_tensor(y.get_tile_distribution()); + sweep_tile(qy, [&](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + auto qy_ = y[idx] / yscale[i_idx]; + qy(idx) = saturates{}(qy_); + }); + store_tile(qy_window, qy); + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(xscale_window, {-Block_N}); // 1-D window: single-component step, matching the moves above + move_tile_window(qy_window, {0, -Block_N}); + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/remod.py b/include/ck_tile/remod.py index 0612d4238d..b0d2c36efe 100644 --- a/include/ck_tile/remod.py +++ b/include/ck_tile/remod.py @@ -1,3 +1,4 @@ +from datetime import datetime import pathlib from pathlib import Path import subprocess @@ -8,8 +9,8 @@ NS = 'ck_tile' OPS = 'ops' OPS_COMMON = 'common' # common header will be duplicated into ops/* other module -HEADER_COMMON = """// SPDX-License-Identifier: MIT -// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.\n +HEADER_COMMON = f"""// SPDX-License-Identifier: MIT +// Copyright (c) 2018-{datetime.now().year}, Advanced Micro Devices, Inc. All rights reserved.\n """ # aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)