From 92b701bb162b944b1d87078a006777eb4fbb9434 Mon Sep 17 00:00:00 2001 From: rocking Date: Wed, 30 Oct 2024 15:22:56 +0800 Subject: [PATCH] [Ck tile] support rmsnorm and related fusion (#1605) * Add reduce2d new api * Prevent user use cross warp reduction * Fix bug of std calculation * Add rmsnorm2d * Add rmsnorm small example * Remove static assert to prevent compile fail * Add script to test performance and correctness * Add missing cmake change * refine naming * refine example of rmsnorm * Fix bug of rmsnorm * Refine naming * Fix cmake * clang format * Refine pipeline name * Add add_rmsnorm2d_rdquant kernel * Add reduce op * host verification * Fix bug of one pass pipeline * Refine tile size * Add two pass pipeline * Rename two pass to three pass * Fix bug of kSaveX == false * Add instance library * Add test script * Fix bug of x verification * Add save_x to trait * Add README * Move reduce2d into reduce folder * Fix bug of welford when number of m warp > 1 * remove redundant comment * 1. move 06_rmsnorm2d to 10_rmsnorm2d 2. move 07_add_rmsnorm2d_rdquant to 11_add_rmsnorm2d_rdquant * clang format and add missing header * Add host validation of add + layernorm2d + rsquant * Revert "Add host validation of add + layernorm2d + rsquant" This reverts commit 936cb457978b928b90eff89a08fcdb7dc8bbed67. 
* Remove deprecated flag [ROCm/composable_kernel commit: 3d60953477bd575e320c84240a9f8ef49eb7bedd] --- example/ck_tile/05_reduce/reduce.cpp | 65 ++-- example/ck_tile/05_reduce/reduce.hpp | 186 +++++++----- example/ck_tile/10_rmsnorm2d/CMakeLists.txt | 25 ++ example/ck_tile/10_rmsnorm2d/README.md | 22 ++ .../10_rmsnorm2d/example_rmsnorm2d_fwd.cpp | 165 +++++++++++ .../instances/rmsnorm2d_fwd_api.cpp | 153 ++++++++++ .../rmsnorm2d_fwd_bf16_n1024_instance.cpp | 22 ++ .../rmsnorm2d_fwd_bf16_n1536_instance.cpp | 13 + .../rmsnorm2d_fwd_bf16_n2048_instance.cpp | 14 + .../rmsnorm2d_fwd_bf16_n256_instance.cpp | 12 + .../rmsnorm2d_fwd_bf16_n3072_instance.cpp | 14 + .../rmsnorm2d_fwd_bf16_n4096_instance.cpp | 14 + .../rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp | 14 + .../rmsnorm2d_fwd_bf16_n512_instance.cpp | 13 + .../rmsnorm2d_fwd_bf16_n64_n128_instance.cpp | 12 + .../rmsnorm2d_fwd_bf16_n768_instance.cpp | 12 + .../rmsnorm2d_fwd_fp16_n1024_instance.cpp | 22 ++ .../rmsnorm2d_fwd_fp16_n1536_instance.cpp | 13 + .../rmsnorm2d_fwd_fp16_n2048_instance.cpp | 14 + .../rmsnorm2d_fwd_fp16_n256_instance.cpp | 12 + .../rmsnorm2d_fwd_fp16_n3072_instance.cpp | 14 + .../rmsnorm2d_fwd_fp16_n4096_instance.cpp | 14 + .../rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp | 14 + .../rmsnorm2d_fwd_fp16_n512_instance.cpp | 13 + .../rmsnorm2d_fwd_fp16_n64_n128_instance.cpp | 12 + .../rmsnorm2d_fwd_fp16_n768_instance.cpp | 12 + .../rmsnorm2d_fwd_instance_common.hpp | 65 ++++ .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 179 +++++++++++ .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp | 117 ++++++++ .../ck_tile/10_rmsnorm2d/script/perf_test.sh | 38 +++ .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 31 ++ .../11_add_rmsnorm2d_rdquant/CMakeLists.txt | 25 ++ .../11_add_rmsnorm2d_rdquant/README.md | 22 ++ .../add_rmsnorm2d_rdquant_fwd.cpp | 279 +++++++++++++++++ .../add_rmsnorm2d_rdquant_fwd.hpp | 123 ++++++++ .../example_add_rmsnorm2d_rdquant_fwd.cpp | 280 ++++++++++++++++++ .../add_rmsnorm2d_rdquant_fwd_api.cpp | 157 
++++++++++ ...norm2d_rdquant_fwd_bf16_n1024_instance.cpp | 22 ++ ...norm2d_rdquant_fwd_bf16_n1536_instance.cpp | 13 + ...norm2d_rdquant_fwd_bf16_n2048_instance.cpp | 14 + ...snorm2d_rdquant_fwd_bf16_n256_instance.cpp | 12 + ...norm2d_rdquant_fwd_bf16_n3072_instance.cpp | 14 + ...norm2d_rdquant_fwd_bf16_n4096_instance.cpp | 14 + ...m2d_rdquant_fwd_bf16_n4096_tp_instance.cpp | 14 + ...snorm2d_rdquant_fwd_bf16_n512_instance.cpp | 13 + ...m2d_rdquant_fwd_bf16_n64_n128_instance.cpp | 12 + ...snorm2d_rdquant_fwd_bf16_n768_instance.cpp | 12 + ...norm2d_rdquant_fwd_fp16_n1024_instance.cpp | 22 ++ ...norm2d_rdquant_fwd_fp16_n1536_instance.cpp | 13 + ...norm2d_rdquant_fwd_fp16_n2048_instance.cpp | 14 + ...snorm2d_rdquant_fwd_fp16_n256_instance.cpp | 12 + ...norm2d_rdquant_fwd_fp16_n3072_instance.cpp | 14 + ...norm2d_rdquant_fwd_fp16_n4096_instance.cpp | 14 + ...m2d_rdquant_fwd_fp16_n4096_tp_instance.cpp | 14 + ...snorm2d_rdquant_fwd_fp16_n512_instance.cpp | 13 + ...m2d_rdquant_fwd_fp16_n64_n128_instance.cpp | 12 + ...snorm2d_rdquant_fwd_fp16_n768_instance.cpp | 12 + ..._rmsnorm2d_rdquant_fwd_instance_common.hpp | 67 +++++ .../script/perf_test.sh | 38 +++ .../script/smoke_test.sh | 31 ++ example/ck_tile/CMakeLists.txt | 3 +- include/ck_tile/core.hpp | 1 + .../ck_tile/core/utility/reduce_operator.hpp | 95 ++++++ include/ck_tile/host.hpp | 3 + .../host/reference/reference_elementwise.hpp | 47 +++ .../host/reference/reference_reduce.hpp | 17 +- .../reference/reference_rmsnorm2d_fwd.hpp | 52 ++++ .../reference_rowwise_quantization2d.hpp | 33 +++ include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp | 12 + .../add_rmsnorm2d_rdquant_fwd_kernel.hpp | 239 +++++++++++++++ .../add_rmsnorm2d_rdquant_fwd_shape.hpp | 78 +++++ ...2d_rdquant_fwd_pipeline_default_policy.hpp | 94 ++++++ ...msnorm2d_rdquant_fwd_pipeline_one_pass.hpp | 142 +++++++++ ...rmsnorm2d_rdquant_fwd_pipeline_problem.hpp | 41 +++ ...norm2d_rdquant_fwd_pipeline_three_pass.hpp | 266 +++++++++++++++++ 
.../layernorm2d_fwd_pipeline_one_pass.hpp | 4 +- .../layernorm2d_fwd_pipeline_two_pass.hpp | 6 +- include/ck_tile/ops/reduce.hpp | 3 + .../ck_tile/ops/reduce/block/block_reduce.hpp | 19 +- .../ops/reduce/block/block_reduce2d.hpp | 260 ++++++++++++++++ .../block/block_reduce2d_default_policy.hpp | 79 +++++ .../reduce/block/block_reduce2d_problem.hpp | 18 ++ include/ck_tile/ops/rmsnorm2d.hpp | 12 + .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 202 +++++++++++++ .../rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp | 78 +++++ .../rmsnorm2d_fwd_pipeline_default_policy.hpp | 94 ++++++ .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 101 +++++++ .../rmsnorm2d_fwd_pipeline_problem.hpp | 36 +++ .../rmsnorm2d_fwd_pipeline_two_pass.hpp | 131 ++++++++ .../ops/welford/block/block_welford.hpp | 8 +- 90 files changed, 4674 insertions(+), 128 deletions(-) create mode 100644 example/ck_tile/10_rmsnorm2d/CMakeLists.txt create mode 100644 example/ck_tile/10_rmsnorm2d/README.md create mode 100644 example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp create mode 
100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp create mode 100644 example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp create mode 100644 example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp create mode 100755 example/ck_tile/10_rmsnorm2d/script/perf_test.sh create mode 100755 example/ck_tile/10_rmsnorm2d/script/smoke_test.sh create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/README.md create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp create mode 100644 
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp create mode 100644 
example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp create mode 100644 example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp create mode 100755 example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh create mode 100755 example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh create mode 100644 include/ck_tile/core/utility/reduce_operator.hpp create mode 100644 include/ck_tile/host/reference/reference_elementwise.hpp create mode 100644 include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp create mode 100644 include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp create mode 100644 include/ck_tile/ops/reduce/block/block_reduce2d.hpp create mode 100644 include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp create mode 100644 
include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp diff --git a/example/ck_tile/05_reduce/reduce.cpp b/example/ck_tile/05_reduce/reduce.cpp index 7973a8dfdb..005541dc62 100644 --- a/example/ck_tile/05_reduce/reduce.cpp +++ b/example/ck_tile/05_reduce/reduce.cpp @@ -19,9 +19,9 @@ auto create_args(int argc, char* argv[]) template bool run(const ck_tile::ArgParser& arg_parser) { - using ADataType = DataType; - using AccDataType = float; - using BDataType = DataType; + using XDataType = DataType; + using ComputeDataType = float; + using YDataType = DataType; ck_tile::index_t m = arg_parser.get_int("m"); ck_tile::index_t n = arg_parser.get_int("n"); @@ -29,35 +29,39 @@ bool run(const ck_tile::ArgParser& arg_parser) int warmup = arg_parser.get_int("warmup"); int repeat = arg_parser.get_int("repeat"); - ck_tile::HostTensor a_host({m, n}); - ck_tile::HostTensor b_host_ref({m}); - ck_tile::HostTensor b_host_dev({m}); + ck_tile::HostTensor x_host({m, n}); + ck_tile::HostTensor y_host_ref({m}); + ck_tile::HostTensor y_host_dev({m}); - ck_tile::FillUniformDistribution{-5.f, 5.f}(a_host); + ck_tile::FillUniformDistribution{-5.f, 5.f}(x_host); - ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); - ck_tile::DeviceMem b_buf(b_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem 
y_buf(y_host_dev.get_element_space_size_in_bytes()); - a_buf.ToDevice(a_host.data()); + x_buf.ToDevice(x_host.data()); + using ReduceOp = ck_tile::ReduceOp::Add; using BlockWarps = ck_tile::sequence<4, 1>; using BlockTile = ck_tile::sequence<128, 128>; using WarpTile = ck_tile::sequence<32, 128>; - using ThreadTile = ck_tile::sequence<8, 8>; + using Vector = ck_tile::sequence<8, 8>; - constexpr ck_tile::index_t kBlockSize = 256; + // cross warp-reduce + // using BlockWarps = ck_tile::sequence<2, 2>; + // using BlockTile = ck_tile::sequence<2, 1024>; + // using WarpTile = ck_tile::sequence<1, 512>; + // using Vector = ck_tile::sequence<1, 8>; + + constexpr ck_tile::index_t kBlockSize = 512; constexpr ck_tile::index_t kBlockPerCu = 1; ck_tile::index_t kGridSize = (m / BlockTile::at(ck_tile::number<0>{})); std::cout << "grid size " << kGridSize << std::endl; - using Kernel = ck_tile::Reduce; + using Shape = ck_tile::Reduce2dShape; + using Porblem = + ck_tile::Reduce2dProblem; + + using Kernel = ck_tile::Reduce; float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat}, ck_tile::make_kernel( @@ -65,12 +69,12 @@ bool run(const ck_tile::ArgParser& arg_parser) kGridSize, kBlockSize, 0, - static_cast(a_buf.GetDeviceBuffer()), - static_cast(b_buf.GetDeviceBuffer()), + static_cast(x_buf.GetDeviceBuffer()), + static_cast(y_buf.GetDeviceBuffer()), m, n)); - std::size_t num_btype = sizeof(ADataType) * m * n + sizeof(BDataType) * m; + std::size_t num_btype = sizeof(XDataType) * m * n + sizeof(YDataType) * m; float gb_per_sec = num_btype / 1.E6 / ave_time; @@ -81,9 +85,10 @@ bool run(const ck_tile::ArgParser& arg_parser) if(do_validation) { // reference - ck_tile::reference_reduce(a_host, b_host_ref); - b_buf.FromDevice(b_host_dev.mData.data()); - pass = ck_tile::check_err(b_host_dev, b_host_ref); + ck_tile::reference_reduce( + x_host, y_host_ref, ReduceOp{}); + y_buf.FromDevice(y_host_dev.mData.data()); + pass = ck_tile::check_err(y_host_dev, 
y_host_ref); std::cout << "valid:" << (pass ? "y" : "n") << std::flush << std::endl; } @@ -103,8 +108,8 @@ int main(int argc, char* argv[]) { return run(arg_parser) ? 0 : -2; } - if(data_type == "bf16") - { - return run(arg_parser) ? 0 : -2; - } + // else if(data_type == "bf16") + // { + // return run(arg_parser) ? 0 : -2; + // } } diff --git a/example/ck_tile/05_reduce/reduce.hpp b/example/ck_tile/05_reduce/reduce.hpp index e36b468951..55e479591c 100644 --- a/example/ck_tile/05_reduce/reduce.hpp +++ b/example/ck_tile/05_reduce/reduce.hpp @@ -5,20 +5,16 @@ #include "ck_tile/core.hpp" #include "ck_tile/ops/common.hpp" - #include "ck_tile/ops/reduce/block/block_reduce.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp" namespace ck_tile { -template +template typename BlockTile, // block size, seq typename WarpTile, // warp size, seq - typename ThreadTile> // contiguous pixels(vector size) along seq -struct Reduce + typename Vector> // contiguous pixels(vector size) along seq +struct Reduce2dShape { static constexpr index_t Block_M = BlockTile::at(number<0>{}); static constexpr index_t Block_N = BlockTile::at(number<1>{}); @@ -26,93 +22,143 @@ struct Reduce static constexpr index_t Warp_M = WarpTile::at(number<0>{}); static constexpr index_t Warp_N = WarpTile::at(number<1>{}); - static constexpr index_t Thread_M = ThreadTile::at(number<0>{}); - static constexpr index_t Thread_N = ThreadTile::at(number<1>{}); + static constexpr index_t Vector_M = Vector::at(number<0>{}); + static constexpr index_t Vector_N = Vector::at(number<1>{}); static constexpr index_t WarpPerBlock_M = BlockWarps::at(number<0>{}); static constexpr index_t WarpPerBlock_N = BlockWarps::at(number<1>{}); - static constexpr index_t ThreadPerWarp_M = Warp_M / Thread_M; - static constexpr index_t ThreadPerWarp_N = Warp_N / Thread_N; + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; static constexpr 
index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); - __device__ static constexpr auto MakeABlockTileDistribution() + static constexpr index_t BlockSize = + warpSize * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{}); +}; + +template +struct Reduce2dProblem +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + using ReduceOp = ReduceOp_; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; +}; + +template +struct Reduce +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + +#if 0 + CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) + const { - return make_static_tile_distribution( - tile_distribution_encoding< - sequence<>, - tuple, - sequence>, - tuple, sequence<1, 2>>, - tuple, sequence<2, 2>>, - sequence<1, 1, 2, 2>, - sequence<0, 3, 0, 3>>{}); - } + using S = typename Problem::BlockShape; - __device__ void operator()(const ADataType* p_a, BDataType* p_b, index_t M, index_t N) const - { - const auto a_m_n = make_naive_tensor_view( - p_a, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); - const auto iM = get_block_id() * Block_M; + const auto y_m = make_naive_tensor_view_packed( + p_y, make_tuple(M), number<1>{}); - // A window - auto a_block_window = make_tile_window(a_m_n, - make_tuple(number{}, number{}), - {iM, 0}, - MakeABlockTileDistribution()); + const auto iM = get_block_id() * S::Block_M; + + auto x_window = 
make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + auto y_window = make_tile_window(y_m, make_tuple(number{}), {iM}); const auto f_reduce = [](const auto& v0, const auto& v1) { return v0 + v1; }; - const ADataType reduce_init_value = 0; + const XDataType reduce_init_value = 0; constexpr auto reduce_dims = sequence<1>{}; - // Acc tile - // TODO: support cross warp reduction - auto acc_block_tensor = decltype(block_tile_reduce( - load_tile(a_block_window), reduce_dims, f_reduce, reduce_init_value)){}; + auto y_compute = decltype(block_tile_reduce( + load_tile(x_window), reduce_dims, f_reduce, reduce_init_value)){}; - // init Acc tile - tile_elementwise_inout( - [&](auto& acc) { acc = type_convert(reduce_init_value); }, - acc_block_tensor); + set_tile(y_compute, reduce_init_value); - // loop - index_t iN = 0; + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); - do + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) { - const auto a_block_tensor = load_tile(a_block_window); + const auto x = load_tile(x_window); + block_tile_reduce(y_compute, x, reduce_dims, f_reduce); + move_tile_window(x_window, {0, S::Block_N}); + } - // FIXME: support cross warp reduction - block_tile_reduce(acc_block_tensor, a_block_tensor, reduce_dims, f_reduce); + block_tile_reduce_sync(y_compute, f_reduce); - move_tile_window(a_block_window, {0, Block_N}); - - iN += Block_N; - - } while(iN < N); - - // FIXME: support cross warp reduction - block_tile_reduce_sync(acc_block_tensor, f_reduce); - - // convert acc_block_tensor to b_block_tensor - const auto b_block_tensor = tile_elementwise_in( - [](const auto& acc) { return type_convert(acc); }, acc_block_tensor); - - // B - const auto b_m = make_naive_tensor_view_packed( - p_b, make_tuple(M), number<32>{}); - - // B window - auto b_block_window = make_tile_window(b_m, make_tuple(number{}), {iM}); 
- - // store B tile - store_tile(b_block_window, b_block_tensor); + store_tile(y_window, cast_tile(y_compute)); } +#else + CK_TILE_DEVICE void operator()(const XDataType* p_x, YDataType* p_y, index_t M, index_t N) const + { + using S = typename Problem::BlockShape; + + const auto x_m_n = make_naive_tensor_view( + p_x, make_tuple(M, N), make_tuple(N, 1), number{}, number<1>{}); + + const auto y_m = make_naive_tensor_view_packed( + p_y, make_tuple(M), number<1>{}); + + const auto iM = get_block_id() * S::Block_M; + + auto x_window = make_tile_window(x_m_n, + make_tuple(number{}, number{}), + {iM, 0}, + Policy::template MakeXBlockTileDistribution()); + + auto y_window = make_tile_window(y_m, make_tuple(number{}), {iM}); + + __shared__ char smem[Policy::template GetSmemSize()]; + + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, S::Block_N)); + + auto reduce_func = typename Problem::ReduceOp{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + using XTensorType = decltype(load_tile(x_window)); + auto y_compute = block_reduce2d.template MakeYBlockTile(); + set_tile(y_compute, reduce_func.template GetIdentityValue()); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + block_reduce2d(x, y_compute, reduce_func); + move_tile_window(x_window, {0, S::Block_N}); + } + + block_reduce2d_sync(y_compute, reduce_func); + block_reduce2d_cross_warp_sync(y_compute, smem, reduce_func); + + store_tile(y_window, cast_tile(y_compute)); + } +#endif }; } // namespace ck_tile diff --git a/example/ck_tile/10_rmsnorm2d/CMakeLists.txt b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt new file mode 100644 index 0000000000..a3ff8fdf45 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/CMakeLists.txt @@ -0,0 +1,25 
@@ +set(TILE_RMSNORM2D_FWD "tile_rmsnorm2d_fwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding ${TILE_RMSNORM2D_FWD}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) +target_include_directories(${TILE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) + +set(TILE_RMSNORM2D_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${TILE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) + +set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd") +add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL example_rmsnorm2d_fwd.cpp) +target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${TILE_RMSNORM2D_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/10_rmsnorm2d/README.md b/example/ck_tile/10_rmsnorm2d/README.md new file mode 100644 index 0000000000..c067496477 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/README.md @@ -0,0 +1,22 @@ +# Rmsnorm2D forward + +This folder contains an example for Rmsnorm2D forward using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace <arch> with gfx90a, gfx942... 
+make tile_rmsnorm2d_fwd -j +``` +This will result in an executable `build/bin/tile_rmsnorm2d_fwd` + +## cmdline +``` +args: + -m m dimension (default:3328) + -n n dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp new file mode 100644 index 0000000000..bb2c949015 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/example_rmsnorm2d_fwd.cpp @@ -0,0 +1,165 @@ +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/rmsnorm2d.hpp" +#include + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using XDataType = DataType; + using YDataType = DataType; + using GammaDataType = DataType; + using InvRmsDataType = ck_tile::null_type; + + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + 
ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor invRms_host_ref({m}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + constexpr bool kTwoPass = true; + + using BlockWarps = ck_tile::sequence<2, 2>; + using BlockTile = ck_tile::sequence<2, 128>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; + + using Shape = ck_tile::Rmsnorm2dShape; + using Problem = ck_tile::Rmsnorm2dFwdPipelineProblem; + + using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::Rmsnorm2dFwd; + + ck_tile::Rmsnorm2dFwdHostArgs args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), + nullptr, + epsilon, + m, + n, + stride}; + + auto kargs = Kernel::MakeKargs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + auto [rtol, atol] = ck_tile::make_tuple(1e-3, 1e-3); + if(stride == n) + { + pass = ck_tile::check_err( + y_host_dev, y_host_ref, std::string("OUT 
Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, + y_host_dev.begin() + i_r * stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, + y_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(y_host_dev_row, + y_host_ref_row, + std::string("OUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp new file mode 100644 index 0000000000..f9cfe72ded --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "rmsnorm2d_fwd.hpp" + +template +using trait_ = rmsnorm2d_fwd_traits_; + +template +float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/, + rmsnorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ +#if 1 + float r = -1; + // clang-format off + // rm rn tm tn vn pd rms 2p + if(a.n <= 64) { + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 
0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + return r; +#else + return rmsnorm2d_fwd_>(s, a); +#endif + // clang-format on +} + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s) +{ + + float r = -1; + if(t.data_type.compare("fp16") == 0) + { + return rmsnorm2d_fwd_b16_(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return rmsnorm2d_fwd_b16_(t, a, s); + } + if(r < 0) + throw std::runtime_error("Without supported instances!"); + + return r; +} diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp new file mode 100644 index 0000000000..5e2a35f9e8 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +#if 0 +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +template float rmsnorm2d_fwd_>(const S&, A); +#endif + +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..8c734806e1 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..9222001433 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000..ed33c84923 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..b753bbc345 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..27cb9bdf3d --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..23afb5672b --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000..b428f58051 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..3001106697 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000..e9c8d6a1d4 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..15198eebe6 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +#if 0 +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +template float rmsnorm2d_fwd_>(const S&, A); +#endif + +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..8ac85fa9b5 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..10e8fafc2f --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000..4e1a80bf64 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..45e56a92b8 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..35401f6f82 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..1e3700fad3 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000..cdc4d00bd2 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..ec80c2ee4a --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000..ddfc5a54e8 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp new file mode 100644 index 0000000000..8f6ff84b64 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp @@ -0,0 +1,65 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "rmsnorm2d_fwd.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = rmsnorm2d_fwd_args; + +template +using trait_ = rmsnorm2d_fwd_traits_; + +template +float rmsnorm2d_fwd_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = + ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, + typename RmsnormTypeConfig::GammaDataType, + typename RmsnormTypeConfig::ComputeDataType, + typename RmsnormTypeConfig::YDataType, + typename RmsnormTypeConfig::InvRmsDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kSaveInvRms, + Traits_::kTwoPass>; + + using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::Rmsnorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp new file mode 100644 index 0000000000..698a8b43eb --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -0,0 +1,179 @@ +#include "ck_tile/host.hpp" +#include "rmsnorm2d_fwd.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", 
"-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = RmsnormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + + using InvRmsDataType = + std::conditional_t; + + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor invRms_host_ref({m}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + 
gamma_buf.ToDevice(gamma_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + rmsnorm2d_fwd_traits traits{data_type, SaveRms}; + + rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), + nullptr, + epsilon, + m, + n, + stride}; + + float ave_time = rmsnorm2d_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + std::size_t num_byte = + sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err( + y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, + y_host_dev.begin() + i_r * stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, + y_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(y_host_dev_row, + y_host_ref_row, + std::string("OUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + int save_rms = arg_parser.get_int("save_rms"); + if(data_type == "fp16" && save_rms) + { + return run(arg_parser) ? 
0 : -2; + } + else if(data_type == "fp16" && !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp new file mode 100644 index 0000000000..756ecb2c4b --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/rmsnorm2d.hpp" +#include + +template +struct RmsnormTypeConfig; + +template <> +struct RmsnormTypeConfig +{ + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using InvRmsDataType = ck_tile::half_t; + using ComputeDataType = float; +}; + +template <> +struct RmsnormTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using YDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using InvRmsDataType = ck_tile::bf16_t; + using ComputeDataType = float; +}; + +// runtime args +struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs +{ +}; + +// this is used to pattern-match internal kernel implementation, not to instantiate kernel +template +struct rmsnorm2d_fwd_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + 
return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Rmsnorm2dShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +template +float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); + +// This is the public API, will be generated by script +struct rmsnorm2d_fwd_traits +{ + std::string data_type; + bool save_rms; +}; + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/10_rmsnorm2d/script/perf_test.sh b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh new file mode 100755 index 0000000000..f3cfcc4b89 --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/script/perf_test.sh @@ -0,0 +1,38 @@ + +# run from top of ck folder +EXE=build/bin/tile_rmsnorm2d_fwd + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 
+$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh 
b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh new file mode 100755 index 0000000000..6ec5e846ce --- /dev/null +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -0,0 +1,31 @@ +#!/bin/sh +# call from top of CK folder +EXE=./build/bin/tile_rmsnorm2d_fwd + +for pr_i in "fp16" "bf16" ; do +$EXE -prec=$pr_i -m=99 -n=13 +$EXE -prec=$pr_i -m=17 -n=16 +$EXE -prec=$pr_i -m=1 -n=100 +$EXE -prec=$pr_i -m=4 -n=128 +$EXE -prec=$pr_i -m=80 -n=127 +$EXE -prec=$pr_i -m=22 -n=255 -stride=256 +$EXE -prec=$pr_i -m=7 -n=599 +$EXE -prec=$pr_i -m=19 -n=512 +$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 +$EXE -prec=$pr_i -m=11 -n=510 +$EXE -prec=$pr_i -m=171 -n=676 -stride=818 +$EXE -prec=$pr_i -m=91 -n=636 +$EXE -prec=$pr_i -m=12 -n=768 -stride=800 +$EXE -prec=$pr_i -m=100 -n=766 -stride=812 +$EXE -prec=$pr_i -m=31 -n=1024 +$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 +$EXE -prec=$pr_i -m=8 -n=1501 +$EXE -prec=$pr_i -m=3 -n=1826 +$EXE -prec=$pr_i -m=5 -n=2040 +$EXE -prec=$pr_i -m=7 -n=2734 +$EXE -prec=$pr_i -m=1 -n=3182 +$EXE -prec=$pr_i -m=9 -n=4096 +$EXE -prec=$pr_i -m=3 -n=8192 +$EXE -prec=$pr_i -m=1 -n=10547 +$EXE -prec=$pr_i -m=3 -n=17134 +done diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt new file mode 100644 index 0000000000..6b0c3cef7a --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/CMakeLists.txt @@ -0,0 +1,25 @@ +set(TILE_ADD_RMSNORM2D_RDQUANT_FWD "tile_add_rmsnorm2d_rdquant_fwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding ${TILE_ADD_RMSNORM2D_RDQUANT_FWD}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL add_rmsnorm2d_rdquant_fwd.cpp) +target_include_directories(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${INSTANCE_SRCS}) 
+ +set(TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${TILE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + +set(EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD "tile_example_add_rmsnorm2d_rdquant_fwd") +add_executable(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} EXCLUDE_FROM_ALL example_add_rmsnorm2d_rdquant_fwd.cpp) +target_compile_options(${EXAMPLE_ADD_RMSNORM2D_RDQUANT_FWD} PRIVATE ${TILE_ADD_RMSNORM2D_RDQUANT_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md new file mode 100644 index 0000000000..960369b78d --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/README.md @@ -0,0 +1,22 @@ +# Add + Rmsnorm2D + rowwise dynamic quantization forward + +This folder contains example for add + Rmsnorm2D + rowwise dynamic quantization forward using ck_tile tile-programming implementation. Rdquant is short for rowwise dynamic quantization here. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... 
+make tile_add_rmsnorm2d_rdquant_fwd -j +``` +This will result in an executable `build/bin/tile_add_rmsnorm2d_rdquant_fwd` + +## cmdline +``` +args: + -m m dimension (default:3328) + -n n dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp new file mode 100644 index 0000000000..43bc9a6cfe --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.cpp @@ -0,0 +1,279 @@ +#include "ck_tile/host.hpp" +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_x", "1", "save rms(invrms) or not.
set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = AddRmsnormRdquantTypeConfig; + + using ADataType = typename TypeConfig::ADataType; + using BDataType = typename TypeConfig::BDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + using XDataType = typename TypeConfig::XDataType; + using YScaleDataType = typename TypeConfig::YScaleDataType; + using QYDataType = typename TypeConfig::QYDataType; + using ComputeDataType = float; + + // host verify + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor x_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem 
a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + add_rmsnorm2d_rdquant_fwd_traits traits{data_type, SaveX}; + + add_rmsnorm2d_rdquant_fwd_args args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + x_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + qy_buf.GetDeviceBuffer(), + epsilon, + m, + n, + stride}; + + float ave_time = add_rmsnorm2d_rdquant_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 
1 : 0, warmup, repeat}); + + std::size_t num_byte = sizeof(ADataType) * m * n + sizeof(BDataType) * m * n + + sizeof(GammaDataType) * n + sizeof(YScaleDataType) * m + + sizeof(QYDataType) * m * n; + + if constexpr(SaveX) + num_byte += sizeof(XDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + using InvRmsDataType = DataType; + + // Add + { + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + ck_tile::reference_binary_elementwise( + a_host, b_host, x_host_ref, op); + + x_buf.FromDevice(x_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + ck_tile::HostTensor y_host({m, n}); + // Rmsnorm2d + { + ck_tile::HostTensor invRms_host_ref({m}); + + // CAUTION: kernel uses ComputeDataType version of x, but we use XDataType here for + // simplicity + ck_tile::reference_rmsnorm2d_fwd( + x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( +
y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + int save_x = arg_parser.get_int("save_x"); + if(data_type == "fp16" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp16" && !save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && save_x) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && !save_x) + { + return run(arg_parser) ? 
0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp new file mode 100644 index 0000000000..bf70d9d23f --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/add_rmsnorm2d_rdquant_fwd.hpp @@ -0,0 +1,123 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" +#include + +template +struct AddRmsnormRdquantTypeConfig; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::half_t; + using BDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using XDataType = ck_tile::half_t; + using YScaleDataType = ck_tile::half_t; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +template <> +struct AddRmsnormRdquantTypeConfig +{ + using ADataType = ck_tile::bf16_t; + using BDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using XDataType = ck_tile::bf16_t; + using YScaleDataType = ck_tile::bf16_t; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = float; +}; + +// runtime args +struct add_rmsnorm2d_rdquant_fwd_args : public ck_tile::AddRmsnorm2dRdquantFwdHostArgs +{ +}; + +// this is used to pattern-match internal kernel implementation, not to instantiate kernel +template +struct add_rmsnorm2d_rdquant_fwd_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { +
static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::AddRmsnorm2dRdquantShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveX = kSaveX_; + static constexpr bool kThreePass = kThreePass_; +}; + +template +float add_rmsnorm2d_rdquant_fwd_(const ck_tile::stream_config& s, add_rmsnorm2d_rdquant_fwd_args a); + +// This is the public API, will be generated by script +struct add_rmsnorm2d_rdquant_fwd_traits +{ + std::string data_type; + bool save_x; +}; + +float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits, + add_rmsnorm2d_rdquant_fwd_args, + const ck_tile::stream_config&); diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp new file mode 100644 index 0000000000..40fabf7f55 --- /dev/null +++ 
b/example/ck_tile/11_add_rmsnorm2d_rdquant/example_add_rmsnorm2d_rdquant_fwd.cpp @@ -0,0 +1,280 @@ +#include "ck_tile/host.hpp" +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + // due to rounding, int8 quantization might have 1 abs error + double rtol = 1; + double atol = 1; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("v", "1", "cpu validation or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "0", "cold iter") + .insert("repeat", "1", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using ADataType = DataType; + using BDataType = DataType; + using GammaDataType = DataType; + using XDataType = DataType; + using YScaleDataType = DataType; + using QYDataType = ck_tile::int8_t; + using ComputeDataType = 
float; + + // host verify + ck_tile::HostTensor a_host({m, n}, {stride, 1}); + ck_tile::HostTensor b_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor x_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor x_host_dev({m, n}, {stride, 1}); + ck_tile::HostTensor yscale_host_ref({m}, {1}); + ck_tile::HostTensor yscale_host_dev({m}, {1}); + ck_tile::HostTensor qy_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor qy_host_dev({m, n}, {stride, 1}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(a_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(b_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem a_buf(a_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem b_buf(b_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem x_buf(x_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem yscale_buf(yscale_host_dev.get_element_space_size_in_bytes()); + ck_tile::DeviceMem qy_buf(qy_host_dev.get_element_space_size_in_bytes()); + + a_buf.ToDevice(a_host.data()); + b_buf.ToDevice(b_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + constexpr bool kThreePass = true; + + using BlockWarps = ck_tile::sequence<2, 2>; + using BlockTile = ck_tile::sequence<2, 128>; + using WarpTile = ck_tile::sequence<1, 64>; + using Vector = ck_tile::sequence<1, 1>; + + using Shape = ck_tile::AddRmsnorm2dRdquantShape; + using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem; + + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass; + using Pipeline = std::conditional_t; + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + + ck_tile::AddRmsnorm2dRdquantFwdHostArgs args{a_buf.GetDeviceBuffer(), + b_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + x_buf.GetDeviceBuffer(), + yscale_buf.GetDeviceBuffer(), + 
qy_buf.GetDeviceBuffer(), + epsilon, + m, + n, + stride}; + + auto kargs = Kernel::MakeKargs(args); + + const dim3 grids = Kernel::GridSize(args); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + auto s = ck_tile::stream_config{nullptr, true, 0, warmup, repeat}; + + ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); + + bool pass = true; + + if(do_validation) + { + using YDataType = ComputeDataType; + using InvRmsDataType = DataType; + + // Add + { + auto op = [](const auto& v0, const auto& v1) { return v0 + v1; }; + ck_tile::reference_binary_elementwise( + a_host, b_host, x_host_ref, op); + + x_buf.FromDevice(x_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err( + x_host_dev, x_host_ref, std::string("x Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector x_host_dev_row(x_host_dev.begin() + i_r * stride, + x_host_dev.begin() + i_r * stride + n); + std::vector x_host_ref_row(x_host_ref.begin() + i_r * stride, + x_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(x_host_dev_row, + x_host_ref_row, + std::string("x[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + ck_tile::HostTensor y_host({m, n}); + // Rmsnorm2d + { + ck_tile::HostTensor invRms_host_ref({m}); + + // CAUTION: kernel uses ComputeDataType version of x, but we use XDataType here for + // simplicity + ck_tile::reference_rmsnorm2d_fwd( + x_host_ref, gamma_host, y_host, invRms_host_ref, epsilon); + } + + // yscale + { + ck_tile::HostTensor y_rowwise_amax_host({m}); + + using ReduceAmax = ck_tile::ReduceOp::AbsMax; + ck_tile::reference_reduce( + y_host, y_rowwise_amax_host, ReduceAmax{}); + + auto op = [](const auto& v0) { + return v0 / + ck_tile::type_convert(ck_tile::numeric::max()); + }; + ck_tile::reference_unary_elementwise( +
y_rowwise_amax_host, yscale_host_ref, op); + + yscale_buf.FromDevice(yscale_host_dev.mData.data()); + + auto [rtol, atol] = get_elimit(); + pass &= ck_tile::check_err(yscale_host_dev, + yscale_host_ref, + std::string("yscale Error: Incorrect results!"), + rtol, + atol); + } + + // rowwise quantization + { + ck_tile::reference_rowwise_quantization2d( + y_host, yscale_host_ref, qy_host_ref); + + qy_buf.FromDevice(qy_host_dev.data()); + auto [rtol, atol] = get_elimit(); + + if(stride == n) + { + pass = ck_tile::check_err(qy_host_dev, + qy_host_ref, + std::string("qy Error: Incorrect results!"), + rtol, + atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector qy_host_dev_row(qy_host_dev.begin() + i_r * stride, + qy_host_dev.begin() + i_r * stride + n); + std::vector qy_host_ref_row(qy_host_ref.begin() + i_r * stride, + qy_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(qy_host_dev_row, + qy_host_ref_row, + std::string("qy[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + } + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride + << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + if(data_type == "fp16") + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp new file mode 100644 index 0000000000..57a0f254d0 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_api.cpp @@ -0,0 +1,157 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_b16_(add_rmsnorm2d_rdquant_fwd_traits /*t*/, + add_rmsnorm2d_rdquant_fwd_args a, + const ck_tile::stream_config& s) +{ +#if 1 + float r = -1; + // clang-format off + // rm rn tm tn vn pd x 3p + if(a.n <= 64) { + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = 
add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 4 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else if (a.n % 2 == 0) + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + else + r = add_rmsnorm2d_rdquant_fwd_>(s, a); + } + return r; +#else + return add_rmsnorm2d_rdquant_fwd_>(s, a); +#endif + // clang-format on +} + +float add_rmsnorm2d_rdquant_fwd(add_rmsnorm2d_rdquant_fwd_traits t, + add_rmsnorm2d_rdquant_fwd_args a, + const ck_tile::stream_config& s) +{ + + float r = -1; + // Only support instance of save_x == true for now + assert(t.save_x); + if(t.data_type.compare("fp16") == 0) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return add_rmsnorm2d_rdquant_fwd_b16_(t, a, s); + } + if(r < 0) + throw std::runtime_error("Without supported instances!"); + + return r; +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp new file mode 100644 index 0000000000..5495e3c9ab --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..8bbfdc8589 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..381a11fc80 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000..2fefac6934 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..263713bbc7 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..c62c596fab --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..e4951f6ab9 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000..4c7ee48e8e --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..8659dc82b3 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000..5f15f11b47 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..8ffdacbdcd --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +#if 0 +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +#endif + +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..3551099651 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. 
All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..d4d0474c27 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000..2cb300eda6 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..fb0ceb4c58 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..3a241a3c93 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..d3094679f9 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000..919bc177e8 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..8a44f5e00f --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000..5c4f05ec3c --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#include "add_rmsnorm2d_rdquant_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd x 3p +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +template float add_rmsnorm2d_rdquant_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp new file mode 100644 index 0000000000..6baaad471a --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/instances/add_rmsnorm2d_rdquant_fwd_instance_common.hpp @@ -0,0 +1,67 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "add_rmsnorm2d_rdquant_fwd.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = add_rmsnorm2d_rdquant_fwd_args; + +template +using trait_ = add_rmsnorm2d_rdquant_fwd_traits_; + +template +float add_rmsnorm2d_rdquant_fwd_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem< + typename AddRmsnormRdquantTypeConfig::ADataType, + typename AddRmsnormRdquantTypeConfig::BDataType, + typename AddRmsnormRdquantTypeConfig::GammaDataType, + typename AddRmsnormRdquantTypeConfig::ComputeDataType, + typename AddRmsnormRdquantTypeConfig::XDataType, + typename AddRmsnormRdquantTypeConfig::YScaleDataType, + typename AddRmsnormRdquantTypeConfig::QYDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kSaveX, + Traits_::kThreePass>; + + using OnePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass; + using ThreePassPipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineThreePass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::AddRmsnorm2dRdquantFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = 
Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh new file mode 100755 index 0000000000..11fd364886 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/perf_test.sh @@ -0,0 +1,38 @@ + +# run from top of ck folder +EXE=build/bin/tile_add_rmsnorm2d_rdquant_fwd + +$EXE -m=1 -n=1 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=288 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=bf16 -repeat=1000 + +$EXE -m=700 -n=80 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=128 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=144 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=168 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=184 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=256 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 
-n=288 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=344 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=376 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=448 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=512 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=924 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1024 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1078 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=1996 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 +$EXE -m=700 -n=4080 -e=1e-12 -v=1 -prec=fp16 -repeat=1000 \ No newline at end of file diff --git a/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh new file mode 100755 index 0000000000..4a02cdcb65 --- /dev/null +++ b/example/ck_tile/11_add_rmsnorm2d_rdquant/script/smoke_test.sh @@ -0,0 +1,31 @@ +#!/bin/sh +# call from top of CK folder +EXE=./build/bin/tile_add_rmsnorm2d_rdquant_fwd + +for pr_i in "fp16" "bf16" ; do +$EXE -prec=$pr_i -m=99 -n=13 +$EXE -prec=$pr_i -m=17 -n=16 +$EXE -prec=$pr_i -m=1 -n=100 +$EXE -prec=$pr_i -m=4 -n=128 +$EXE -prec=$pr_i -m=80 -n=127 +$EXE -prec=$pr_i -m=22 -n=255 -stride=256 +$EXE -prec=$pr_i -m=7 -n=599 +$EXE -prec=$pr_i -m=19 -n=512 +$EXE -prec=$pr_i -m=33 -n=313 -stride=1000 +$EXE -prec=$pr_i -m=11 -n=510 +$EXE -prec=$pr_i -m=171 -n=676 -stride=818 +$EXE -prec=$pr_i -m=91 -n=636 +$EXE -prec=$pr_i -m=12 -n=768 -stride=800 +$EXE -prec=$pr_i -m=100 -n=766 -stride=812 +$EXE -prec=$pr_i -m=31 -n=1024 +$EXE -prec=$pr_i -m=64 -n=1000 -stride=1004 +$EXE -prec=$pr_i -m=8 -n=1501 +$EXE -prec=$pr_i -m=3 -n=1826 +$EXE -prec=$pr_i -m=5 -n=2040 +$EXE -prec=$pr_i -m=7 -n=2734 +$EXE -prec=$pr_i -m=1 -n=3182 +$EXE -prec=$pr_i -m=9 -n=4096 +$EXE -prec=$pr_i -m=3 -n=8192 +$EXE -prec=$pr_i -m=1 -n=10547 +$EXE -prec=$pr_i -m=3 -n=17134 +done diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index c85e313413..e404e5019e 100644 --- 
a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -9,4 +9,5 @@ add_subdirectory(04_img2col) add_subdirectory(05_reduce) add_subdirectory(06_permute) add_subdirectory(09_topk_softmax) - +add_subdirectory(10_rmsnorm2d) +add_subdirectory(11_add_rmsnorm2d_rdquant) diff --git a/include/ck_tile/core.hpp b/include/ck_tile/core.hpp index 14991d375a..fa4b8d3cc4 100644 --- a/include/ck_tile/core.hpp +++ b/include/ck_tile/core.hpp @@ -59,6 +59,7 @@ #include "ck_tile/core/utility/magic_div.hpp" #include "ck_tile/core/utility/philox_rand.hpp" #include "ck_tile/core/utility/random.hpp" +#include "ck_tile/core/utility/reduce_operator.hpp" #include "ck_tile/core/utility/to_sequence.hpp" #include "ck_tile/core/utility/transpose_vectors.hpp" #include "ck_tile/core/utility/type_traits.hpp" diff --git a/include/ck_tile/core/utility/reduce_operator.hpp b/include/ck_tile/core/utility/reduce_operator.hpp new file mode 100644 index 0000000000..8b15d187fe --- /dev/null +++ b/include/ck_tile/core/utility/reduce_operator.hpp @@ -0,0 +1,95 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/config.hpp" + +namespace ck_tile { + +namespace ReduceOp { +// y = ReduceOp(y, x); +struct Add +{ + template + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return y + x; + } + + template || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const + { + float y_ = type_convert(y); + float x_ = type_convert(x); + + return type_convert(y_ + x_); + } +}; + +struct SquareAdd +{ + template + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return type_convert(0.0f); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return y + (x * x); + } +}; + +struct Max +{ + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return numeric::min(); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return max(y, x); + } +}; + +struct AbsMax +{ + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue() + { + return numeric::min(); + }; + + template || std::is_same_v || + std::is_same_v || std::is_same_v>> + CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const + { + return max(y, abs(x)); + } +}; + +} // namespace ReduceOp +} // namespace ck_tile diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index a17ce751c2..c0ab13ce3d 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -19,11 +19,14 @@ #include "ck_tile/host/reference/reference_batched_masking.hpp" #include 
"ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp" +#include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_permute.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" +#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" +#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/stream_config.hpp" diff --git a/include/ck_tile/host/reference/reference_elementwise.hpp b/include/ck_tile/host/reference/reference_elementwise.hpp new file mode 100644 index 0000000000..809049fa64 --- /dev/null +++ b/include/ck_tile/host/reference/reference_elementwise.hpp @@ -0,0 +1,47 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { +template +CK_TILE_HOST void reference_unary_elementwise(const HostTensor& a, + HostTensor& b, + ElementOp element_op) +{ + // TODO: imeplement gpu version reference function + auto f = [&](auto i) { + auto v_a = type_convert(a.mData[i]); + auto v_b = element_op(v_a); + b.mData[i] = ck_tile::type_convert(v_b); + }; + + make_ParallelTensorFunctor(f, b.get_element_space_size())(std::thread::hardware_concurrency()); +} + +template +CK_TILE_HOST void reference_binary_elementwise(const HostTensor& a, + const HostTensor& b, + HostTensor& c, + ElementOp element_op) +{ + // TODO: imeplement gpu version reference function + auto f = [&](auto i) { + auto v_a = type_convert(a.mData[i]); + auto v_b = type_convert(b.mData[i]); + auto v_c = element_op(v_a, v_b); + c.mData[i] = ck_tile::type_convert(v_c); + }; + + make_ParallelTensorFunctor(f, c.get_element_space_size())(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_reduce.hpp b/include/ck_tile/host/reference/reference_reduce.hpp index b16cee3f94..8f8aa23670 100644 --- a/include/ck_tile/host/reference/reference_reduce.hpp +++ b/include/ck_tile/host/reference/reference_reduce.hpp @@ -9,24 +9,25 @@ namespace ck_tile { -template -CK_TILE_HOST void reference_reduce(const HostTensor& a_m_n, HostTensor& b_m) +template +CK_TILE_HOST void +reference_reduce(const HostTensor& x_m_n, HostTensor& y_m, ReduceOp reduce_op) { auto f = [&](auto m) { - const int N = a_m_n.mDesc.get_lengths()[1]; + const int N = x_m_n.mDesc.get_lengths()[1]; - AccDataType v_acc = 0; + ComputeDataType v_acc = reduce_op.template GetIdentityValue(); for(int n = 0; n < N; ++n) { - const ADataType v_a = a_m_n(m, n); + const ComputeDataType v_a = type_convert(x_m_n(m, n)); - v_acc += v_a; + v_acc = reduce_op(v_acc, v_a); } - b_m(m) = ck_tile::type_convert(v_acc); + y_m(m) = 
ck_tile::type_convert(v_acc); }; - make_ParallelTensorFunctor(f, b_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); + make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); } } // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp new file mode 100644 index 0000000000..db6e92f4c0 --- /dev/null +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, + const HostTensor& gamma_n, + HostTensor& y_m_n, + HostTensor& invRms_m, + ComputeDataType epsilon) +{ + auto rmsnorm2d_fwd_func = [&](auto m) { + const int N = x_m_n.mDesc.get_lengths()[1]; + + ComputeDataType mean_square = 0; + ComputeDataType divisor = 0; + + for(int n = 0; n < N; ++n) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + mean_square += x * x; + } + + mean_square = mean_square / N; + divisor = ck_tile::type_convert(1) / ck_tile::sqrt(mean_square + epsilon); + + if constexpr(!std::is_same_v) + invRms_m(m) = ck_tile::type_convert(divisor); + + for(int n = 0; n < N; ++n) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); + auto y = x * divisor * gamma; + y_m_n(m, n) = ck_tile::type_convert(y); + } + }; + + make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])( + std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp new file mode 100644 index 0000000000..e9a398876f --- 
/dev/null +++ b/include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp @@ -0,0 +1,33 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" +#include + +namespace ck_tile { +template +CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor& x_m_n, + const HostTensor& scale_m, + HostTensor& qx_m_n) +{ + auto f = [&](auto m) { + const int N = x_m_n.mDesc.get_lengths()[1]; + + for(int n = 0; n < N; ++n) + { + auto v_x = x_m_n(m, n); + // scale = amax / 127 for int8 + auto v_scale = type_convert(scale_m(m)); + auto v_qx = v_x / v_scale; + qx_m_n(m, n) = saturates{}(v_qx); + } + }; + + make_ParallelTensorFunctor(f, + scale_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} + +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp new file mode 100644 index 0000000000..eb06fea2dd --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp" +#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp new file mode 100644 index 0000000000..4a0e290352 --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp @@ -0,0 +1,239 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +// host side args +struct AddRmsnorm2dRdquantFwdHostArgs +{ + const void* p_a; + const void* p_b; + const void* p_gamma; + + void* p_x; + void* p_yscale; + void* p_qy; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride +}; + +// TODO: Extract some type to wrapper class +template +struct AddRmsnorm2dRdquantFwd +{ + using Pipeline = remove_cvref_t; + using Problem = typename Pipeline::Problem; + + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using XDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using QYDataType = remove_cvref_t; + + static constexpr bool kSaveX = Problem::kSaveX; + + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kThreePass = Problem::kThreePass; + + static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; + static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; + static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + struct Kargs + { + const void* p_a; + const void* p_b; + const void* p_gamma; + + void* p_x; + void* p_yscale; + void* p_qy; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride + }; + using Hargs = AddRmsnorm2dRdquantFwdHostArgs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) + { + return Kargs{hargs.p_a, + hargs.p_b, + hargs.p_gamma, + hargs.p_x, + hargs.p_yscale, + hargs.p_qy, + hargs.epsilon, + hargs.m, + hargs.n, + hargs.stride}; + } + + CK_TILE_HOST static 
constexpr auto GridSize(const Hargs& hargs) + { + return integer_divide_ceil(hargs.m, Block_M); + } + + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + // in byte + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + using S_ = typename Problem::BlockShape; + auto surfix = [&] () { + std::string n; + if (kPadN) n += "_pn"; + if (kSaveX) n += "_x"; + if (kThreePass) n += "_2p"; + return n; }(); + + #define _SS_ std::string + #define _TS_ std::to_string + return _SS_("add_rmsnorm2d_rdquant_fwd_") + _SS_(t2s::name) + "_" + + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + + _SS_(Pipeline::name) + surfix; + #undef _SS_ + #undef _TS_ + // clang-format on + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const auto iM = get_block_id() * Block_M; + + const auto a_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_a), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + const auto b_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + 
static_cast(kargs.p_b), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + const auto gamma_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_gamma), + make_tuple(kargs.n), + make_tuple(1), + number{}, + number<1>{}); + + const auto tmp2_ = + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + + return make_tile_window(tmp2_, make_tuple(number{}), {0}); + }(); + + auto x_window = [&]() { + if constexpr(kSaveX) + { + const auto tmp2_ = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_x), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + return pad_tensor_view(tmp_, + make_tuple(number{}, number{}), + sequence{}); + }(); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + } + else + return make_null_tile_window(make_tuple(number{}, number{})); + }(); + + auto yscale_window = [&]() { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_yscale), + make_tuple(kargs.m), + make_tuple(1), + number<1>{}); + + auto tmp2_ = pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + return make_tile_window(tmp2_, make_tuple(number{}), {iM}); + }(); + + auto qy_window = [&]() { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_qy), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + __shared__ char smem[GetSmemSize()]; + + Pipeline{}(a_window, + b_window, + gamma_window, + x_window, + yscale_window, + qy_window, + static_cast(kargs.epsilon), + kargs.n, + smem); + } +}; + +} // namespace ck_tile 
diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp new file mode 100644 index 0000000000..a17c53c73f --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +/* +// clang-format off + +4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector + + Block_N (Warp_N * WarpPerBlock_N * Repeat_N ) + +<----------------------< Repeat_N(2)>--------------------->+ + | | + +<-- -->+ + Warp_N + +--------------+--------------+--------------+--------------+----+----------------+ + Warp_M | wrap_0 | wrap_1 | | ^ ^ + +--------------+--------------+ | | + | wrap_2 | wrap_3 | | v + +--------------+--------------+--------------+--------------+----+ Block_M + | | | + + + | + | | | v + +--------------+--------------+--------------+--------------+ + + + each Warp-tile (e.g 16 thrd per row) + + Vector_N (contiguous pixels each thrd holds along N, or vector size) + +-----------+-----------+-----------+-----------+-----------+ + | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M + +-----------+-----------+-----------+-----------+-----------+ + | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ... 
+ +-----------+-----------+-----------+-----------+-----------+ +// clang-format on +*/ +template + typename WarpPerBlock_, // num warps along seq + typename WarpTile_, // warp size, seq + typename Vector_, // contiguous pixels(vector size) along seq + index_t BlockSize_ = + warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> +struct AddRmsnorm2dRdquantShape +{ + // block size + static constexpr index_t Block_M = BlockTile_::at(number<0>{}); + static constexpr index_t Block_N = BlockTile_::at(number<1>{}); + + // num warps along seq, within each block + static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); + + // warp size + static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); + + static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); + static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); + // repeat of each thread along seq + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + // vector size along seq + static constexpr index_t Vector_M = Vector_::at(number<0>{}); + static constexpr index_t Vector_N = Vector_::at(number<1>{}); + + static_assert(Warp_M % Vector_M == 0); + static_assert(Warp_N % Vector_N == 0); + // num of threads along seq, within each warp + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t BlockSize = BlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..73ba633b15 --- /dev/null +++ 
b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d.hpp" + +namespace ck_tile { + +struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeABXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template + CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2d{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dCrossWarpSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + if constexpr(Problem::kNeedCrossWarpSync) + { + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = + decltype(make_static_distributed_tensor( + MakeABXBlockTileDistribution())); + using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile()); + + 
return GetBlockReduce2dCrossWarpSync().template GetSmemSize(); + } + else + { + return 1; // zero size arrays are an extension + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp new file mode 100644 index 0000000000..12a15938ae --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct AddRmsnorm2dRdquantFwdPipelineOnePass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using XDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + using QYDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveX = Problem::kSaveX; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_op"; // block per row + else + return "wpr_op"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const AWindow& a_window_, 
+ const BWindow& b_window_, + const GammaWindow& gamma_window_, + XWindow& x_window, + YScaleWindow& yscale_window, + QYWindow& qy_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + const auto a_window = + make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); + const auto b_window = + make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto reduce_absmax_func = ReduceOp::AbsMax{}; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + const auto gamma = load_tile(gamma_window); + + auto x = tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + type_convert(b_); + }, + a, + b); + + if constexpr(kSaveX) + store_tile(x_window, cast_tile(x)); + + // compute mean square, each-thread->cross-lane->cross-warp + auto square_sum = block_reduce2d( + x, reduce_square_sum_func.GetIdentityValue(), reduce_square_sum_func); + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + // rmsnorm computation + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ 
= type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms_[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + + // compute absmax, each-thread->cross-lane->cross-warp + auto absmax = block_reduce2d( + y, reduce_absmax_func.GetIdentityValue(), reduce_absmax_func); + block_reduce2d_sync(absmax, reduce_max_func); + block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); + + // ex: yscale = absmax / 127 if int8 + auto yscale = tile_elementwise_in( + [&](const auto& v_) { + return v_ / type_convert(numeric::max()); + }, + absmax); + store_tile(yscale_window, cast_tile(yscale)); + + // quantize y to qy + auto qy = make_static_distributed_tensor(y.get_tile_distribution()); + sweep_tile(qy, [&, yscale_ = yscale](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + auto qy_ = y[idx] / yscale_[i_idx]; + qy(idx) = saturates{}(qy_); + }); + store_tile(qy_window, qy); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp new file mode 100644 index 0000000000..106e5086be --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp @@ -0,0 +1,41 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale) +template +struct AddRmsnorm2dRdquantFwdPipelineProblem +{ + using ADataType = remove_cvref_t; + using BDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using XDataType = remove_cvref_t; + using YScaleDataType = remove_cvref_t; + using QYDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveX = kSaveX_; + static constexpr bool kThreePass = kThreePass_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp new file mode 100644 index 0000000000..0dbb20645a --- /dev/null +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -0,0 +1,266 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct AddRmsnorm2dRdquantFwdPipelineThreePass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using ADataType = ck_tile::remove_cvref_t; + using BDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using XDataType = ck_tile::remove_cvref_t; + using YScaleDataType = ck_tile::remove_cvref_t; + using QYDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveX = Problem::kSaveX; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_tp"; // block per row + else + return "wpr_tp"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const AWindow& a_window_, + const BWindow& b_window_, + const GammaWindow& gamma_window_, + XWindow& x_window_, + YScaleWindow& yscale_window, + QYWindow& qy_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + auto a_window = + make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); + auto b_window = + make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); + auto x_window = [&]() { + if constexpr(kSaveX) + return make_tile_window(x_window_, + Policy::template MakeABXBlockTileDistribution()); + else + return x_window_; + }(); + auto gamma_window = make_tile_window( + gamma_window_, Policy::template 
MakeGammaBlockTileDistribution()); + + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto reduce_absmax_func = ReduceOp::AbsMax{}; + auto reduce_max_func = ReduceOp::Max{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + + using XTensorType = decltype(cast_tile(load_tile(a_window))); + auto square_sum = block_reduce2d.template MakeYBlockTile(); + set_tile(square_sum, reduce_square_sum_func.GetIdentityValue()); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + + auto x = tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + type_convert(b_); + }, + a, + b); + + if constexpr(kSaveX) + store_tile(x_window, cast_tile(x)); + + block_reduce2d(x, square_sum, reduce_square_sum_func); + move_tile_window(x_window, {0, Block_N}); + move_tile_window(a_window, {0, Block_N}); + move_tile_window(b_window, {0, Block_N}); + } + + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = + row_size % Block_N == 0 ? 
row_size - Block_N : row_size - row_size % Block_N; + + if constexpr(kSaveX) + move_tile_window(x_window, {0, -Block_N}); + else + { + move_tile_window(a_window, {0, -Block_N}); + move_tile_window(b_window, {0, -Block_N}); + } + move_tile_window(gamma_window, {stride_to_right_most_window}); + + using YTensorType = XTensorType; + auto absmax = block_reduce2d.template MakeYBlockTile(); + set_tile(absmax, reduce_absmax_func.GetIdentityValue()); + + // rmsnorm computation + absmax(threadwise reduce) + if constexpr(kSaveX) + __syncthreads(); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + auto x = [&]() { + if constexpr(kSaveX) + { + return load_tile(x_window); + } + else + { + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + return tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + + type_convert(b_); + }, + a, + b); + } + }(); + + auto gamma = load_tile(gamma_window); + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile(y, [&](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + + block_reduce2d(y, absmax, reduce_absmax_func); + + if constexpr(kSaveX) + move_tile_window(x_window, {0, -Block_N}); + else + { + move_tile_window(a_window, {0, -Block_N}); + move_tile_window(b_window, {0, -Block_N}); + } + move_tile_window(gamma_window, {-Block_N}); + } + + // compute absmax, cross-lane->cross-warp + block_reduce2d_sync(absmax, reduce_max_func); + block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func); + + // ex: yscale = absmax / 127 if int8 + auto yscale = tile_elementwise_in( + [&](const auto& v_) { + return v_ / type_convert(numeric::max()); + }, + absmax); + store_tile(yscale_window, 
cast_tile(yscale)); + + // quantize y to qy + // recompute rmsnorm, try to save y in the future + if constexpr(kSaveX) + move_tile_window(x_window, {0, Block_N}); + else + { + move_tile_window(a_window, {0, Block_N}); + move_tile_window(b_window, {0, Block_N}); + } + move_tile_window(gamma_window, {Block_N}); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + auto x = [&]() { + if constexpr(kSaveX) + { + return load_tile(x_window); + } + else + { + const auto a = load_tile(a_window); + const auto b = load_tile(b_window); + return tile_elementwise_in( + [&](const auto& a_, const auto& b_) { + return type_convert(a_) + + type_convert(b_); + }, + a, + b); + } + }(); + + auto gamma = load_tile(gamma_window); + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + auto qy = make_static_distributed_tensor(y.get_tile_distribution()); + + sweep_tile(y, [&](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms[i_idx] * gamma_; + auto qy_ = y_ / yscale[i_idx]; + qy(idx) = saturates{}(qy_); + }); + + store_tile(qy_window, qy); + + if constexpr(kSaveX) + move_tile_window(x_window, {0, Block_N}); + else + { + move_tile_window(a_window, {0, Block_N}); + move_tile_window(b_window, {0, Block_N}); + } + move_tile_window(gamma_window, {Block_N}); + move_tile_window(qy_window, {0, Block_N}); + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp index bf002141b8..c767a472a9 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp @@ -35,9 +35,9 @@ struct 
Layernorm2dFwdPipelineOnePass static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) - return "bpr"; // block per row + return "bpr_op"; // block per row else - return "wpr"; // warp per row + return "wpr_op"; // warp per row }(); CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() diff --git a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp index db094ac2a8..e35d02e707 100644 --- a/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp +++ b/include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp @@ -35,9 +35,9 @@ struct Layernorm2dFwdPipelineTwoPass static constexpr const char* name = []() { if constexpr(kNeedCrossWarpSync) - return "bpr"; // block per row + return "bpr_tp"; // block per row else - return "wpr"; // warp per row + return "wpr_tp"; // warp per row }(); CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() @@ -118,8 +118,6 @@ struct Layernorm2dFwdPipelineTwoPass ck_tile::index_t stride_to_right_most_window = row_size % Block_N == 0 ? 
row_size - Block_N : row_size - row_size % Block_N; - // x_window.foo(); - // gamma_window.foo(); move_tile_window(x_window, {0, -Block_N}); move_tile_window(gamma_window, {stride_to_right_most_window}); move_tile_window(beta_window, {stride_to_right_most_window}); diff --git a/include/ck_tile/ops/reduce.hpp b/include/ck_tile/ops/reduce.hpp index a5ba745d29..fe2d24044e 100644 --- a/include/ck_tile/ops/reduce.hpp +++ b/include/ck_tile/ops/reduce.hpp @@ -4,4 +4,7 @@ #pragma once #include "ck_tile/ops/reduce/block/block_reduce.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" #include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/reduce/block/block_reduce.hpp b/include/ck_tile/ops/reduce/block/block_reduce.hpp index 51d55235e8..d9df949cf9 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce.hpp @@ -6,6 +6,7 @@ #include "ck_tile/core.hpp" #include +// This file is not support cross warp reduce namespace ck_tile { /* @@ -15,8 +16,8 @@ namespace ck_tile { // synchronize reduce result (cross lane reduction and broadcast on replicated dimension) template CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, - const ReduceFunc& reduce_func, - bool_constant = {}) + const ReduceFunc& reduce_func, + bool_constant = {}) { using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -115,7 +116,7 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor, */ template CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor, - const ReduceFunc& reduce_func) + const ReduceFunc& reduce_func) { using Dstr = typename AccDistributedTensor_::StaticTileDistribution; using DstrEncode = typename Dstr::DstrEncode; @@ -174,9 +175,9 @@ 
template CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor, - const InDistributedTensor_& in_tensor, - sequence, - const ReduceFunc& reduce_func) + const InDistributedTensor_& in_tensor, + sequence, + const ReduceFunc& reduce_func) { constexpr auto I0 = number<0>{}; constexpr auto I1 = number<1>{}; @@ -249,9 +250,9 @@ template CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor, - sequence in_reduce_dims, - const ReduceFunc& reduce_func, - const InDataType_& reduce_init) + sequence in_reduce_dims, + const ReduceFunc& reduce_func, + const InDataType_& reduce_init) { using InDataType = typename InDistributedTensor_::DataType; using AccDataType = remove_cvref_t; diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp new file mode 100644 index 0000000000..beb8c718e3 --- /dev/null +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -0,0 +1,260 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockReduce2d +{ + // in-thread reduction + using Problem = remove_cvref_t; + using XDataType = typename Problem::XDataType; + using ComputeDataType = typename Problem::ComputeDataType; + + CK_TILE_DEVICE constexpr BlockReduce2d() {} + + template + CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor, + YDistributedTensor_& y_tensor, + const ReduceFunc& reduce_func) + { + constexpr auto I0 = number<0>{}; + constexpr auto I1 = number<1>{}; + + constexpr auto spans = XDistributedTensor_::get_distributed_spans(); + + // FIXME: hard coded to reduce 2nd axis + sweep_tile_span(spans[I0], [&](auto dstr_idx_i0) { + constexpr auto y_dstr_idx = make_tuple(dstr_idx_i0); + + auto y = y_tensor[y_dstr_idx]; + + sweep_tile_span(spans[I1], [&](auto dstr_idx_i1) { + constexpr auto in_dstr_idx = make_tuple(dstr_idx_i0, dstr_idx_i1); + const auto x = ck_tile::type_convert(x_tensor[in_dstr_idx]); + + y = reduce_func(y, x); + }); + + y_tensor(y_dstr_idx) = y; + }); + } + + template + CK_TILE_DEVICE static auto MakeYBlockTile() + { + static_assert(std::is_same_v, "wrong!"); + + // FIXME: hard coded to reduce 2nd axis + constexpr auto reduce_dims = sequence<1>{}; + + constexpr auto dstr = + make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding( + XDistributedTensor_::get_tile_distribution() + .get_static_tile_distribution_encoding(), + reduce_dims)); + + auto tensor = make_static_distributed_tensor(dstr); + + return tensor; + } + + template + CK_TILE_DEVICE auto operator()(const XDistributedTensor_& x_tensor, + const ComputeDataType& reduce_init, + const ReduceFunc& reduce_func) + { + auto y_tensor = MakeYBlockTile(); + set_tile(y_tensor, reduce_init); + (*this)(x_tensor, y_tensor, reduce_func); + + return y_tensor; + } +}; + +template +struct BlockReduce2dSync +{ + using Problem = remove_cvref_t; + + template + CK_TILE_DEVICE void operator()(YDistributedTensor_& 
y_tensor, const ReduceFunc& reduce_func) + { + using Dstr = typename YDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimP = Dstr::get_num_of_dimension_p(); + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_lane = NDimP - 1; + + // const auto ps_idx = make_array(get_warp_id(), get_lane_id()); + // const auto rs_idx = + // y_tensor.get_tile_distribution().calculate_rs_index_from_ps_index(ps_idx); + + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + // loop over thread data + static_for<0, thread_buf_size, 1>{}([&](auto i) { + auto v_local = y_tensor.get_thread_buffer()[i]; + + // cross-lane reduce for replication + // only reduce on R dimension correspond to lane + // (lane id maps to this R dimension) + static_for<0, NDimR, 1>{}([&](auto idim_r) { + // FIXME: nasty to use does_p_own_r_ + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + + constexpr index_t lid_over_rid_derivative = + DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r]; + + static_assert(is_power_of_two_integer(r_length), + "wrong! only support power of 2 reduction"); + + constexpr index_t nstage = integer_log2_floor(r_length); + + // reduction sweep forward + static_for<0, nstage, 1>{}([&](auto istage) { + // xor + index_t src_lane = + (__lane_id()) ^ + (number{}.value); + + // pull data from remote lane + const auto v_remote = warp_shuffle(v_local, src_lane); + + // reduce + v_local = reduce_func(v_local, v_remote); + }); + } + }); + + // TODO - Do we need to broadcast to other lane? 
+ y_tensor.get_thread_buffer()(i) = v_local; + }); + } +}; + +template +struct BlockReduce2dCrossWarpSync +{ + using Problem = remove_cvref_t; + using BlockShape = typename Problem::BlockShape; + + template + CK_TILE_DEVICE static constexpr index_t GetReduceWarps() + { + constexpr index_t num_reduce_warps = [&]() { + using Dstr = typename YDistributedTensor_::StaticTileDistribution; + using DstrEncode = typename Dstr::DstrEncode; + using DstrEncodeDetail = typename DstrEncode::detail; + + constexpr index_t NDimR = Dstr::get_num_of_dimension_r(); + + constexpr index_t idim_p_warp = 0; + + index_t len_ = 1; + static_for<0, NDimR, 1>{}([&](auto idim_r) { + if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_warp][idim_r]) + { + constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r]; + len_ *= r_length; + } + }); + return len_; + }(); + return num_reduce_warps; + } + + // return in byte + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + using DataType = typename YDistributedTensor_::DataType; + // constexpr auto num_reduce_warps = GetReduceWarps(); + + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + // we need to store all data from every wave into smem + // e.g. 
2x2 reduce along N + // -------------> reduce N + // | w0 | w1 | ___> | w01 | + // | w2 | w3 | | w23 | + // + // -> store data from every wave into LDS + // + // + // -------------> reduce N + // | w0 | w1 | w2 | w3 | -----> | w0123 | + // + // -> also store data from every wave into LDS + constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + return num_warps * thread_buf_size * sizeof(DataType); + } + + template + CK_TILE_DEVICE void + operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func) + { + using DataType = typename YDistributedTensor_::DataType; + + constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size(); + + DataType* smem_ptr = reinterpret_cast(smem); + const index_t lane_id = get_lane_id(); + const index_t warp_id = get_warp_id(); + constexpr auto num_reduce_warps = GetReduceWarps(); + constexpr index_t num_warps = BlockShape::BlockSize / warpSize; + const index_t smem_offset = warp_id; + + // skip if nonthing to do + if constexpr(num_reduce_warps == 1) + return; + + // store into smem only for lane-0 within one warp + if(lane_id == 0) + { + static_for<0, thread_buf_size, 1>{}([&](auto i) { + smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i]; + }); + } + block_sync_lds(); + + // load from smem. 
here we let everythread to do compute :) + index_t local_warp_id = warp_id / num_reduce_warps; + index_t local_smem_os = local_warp_id * num_reduce_warps; + DataType all_scratch[thread_buf_size * num_reduce_warps]; + static_for<0, thread_buf_size, 1>{}([&](auto i_0) { + static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { + all_scratch[i_0 * num_reduce_warps + i_1] = + smem_ptr[i_0 * num_warps + local_smem_os + i_1]; + }); + }); + block_sync_lds(); // TODO: we don't need sync here + + static_for<0, thread_buf_size, 1>{}([&](auto i_0) { + // TODO: use descriptor for this + auto v_local = all_scratch[i_0 * num_reduce_warps]; + + // further reduce mean/var + static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { + constexpr auto i_1 = number{}; + const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; + + // reduce + v_local = reduce_func(v_local, v_remote); + }); + + y_tensor.get_thread_buffer()(i_0) = v_local; + }); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp new file mode 100644 index 0000000000..3c547242d5 --- /dev/null +++ b/include/ck_tile/ops/reduce/block/block_reduce2d_default_policy.hpp @@ -0,0 +1,79 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d.hpp" + +namespace ck_tile { + +struct BlockReduce2dDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2d{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dCrossWarpSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + if constexpr(Problem::kNeedCrossWarpSync) + { + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = + decltype(make_static_distributed_tensor( + MakeXBlockTileDistribution())); + using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile()); + + return GetBlockReduce2dCrossWarpSync().template GetSmemSize(); + } + else + { + return 1; // zero size arrays are an extension + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp new file mode 100644 index 0000000000..b75f4f0767 --- /dev/null +++ b/include/ck_tile/ops/reduce/block/block_reduce2d_problem.hpp @@ -0,0 +1,18 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { + +template +struct BlockReduce2dProblem +{ + using XDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp new file mode 100644 index 0000000000..98c60f1b51 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" +#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp new file mode 100644 index 0000000000..99084a25e4 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -0,0 +1,202 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +// host side args +struct Rmsnorm2dFwdHostArgs +{ + const void* p_x; + const void* p_gamma; + + void* p_y; + void* p_invRms; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride +}; + +// TODO: Extract some type to wrapper class +template +struct Rmsnorm2dFwd +{ + using Pipeline = remove_cvref_t; + using Problem = typename Pipeline::Problem; + + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using InvRmsDataType = remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::kSaveInvRms; + + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kTwoPass = Problem::kTwoPass; + + static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; + static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; + static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + struct Kargs + { + const void* p_x; + const void* p_gamma; + + void* p_y; + void* p_invRms; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride + }; + using Hargs = Rmsnorm2dFwdHostArgs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) + { + return Kargs{hargs.p_x, + hargs.p_gamma, + hargs.p_y, + hargs.p_invRms, + hargs.epsilon, + hargs.m, + hargs.n, + hargs.stride}; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) + { + return (hargs.m + Block_M - 1) / Block_M; + } + + CK_TILE_HOST static constexpr auto 
BlockSize() { return Problem::BlockShape::BlockSize; } + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + // in byte + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + using S_ = typename Problem::BlockShape; + auto surfix = [&] () { + std::string n; + if (kPadN) n += "_pn"; + if (kSaveInvRms) n += "_rms"; + if (kTwoPass) n += "_2p"; + return n; }(); + + #define _SS_ std::string + #define _TS_ std::to_string + return _SS_("rmsnorm2d_fwd_") + _SS_(t2s::name) + "_" + + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + + _SS_(Pipeline::name) + surfix; + #undef _SS_ + #undef _TS_ + // clang-format on + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const auto iM = get_block_id() * Block_M; + + const auto x_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_x), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + const auto gamma_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_gamma), + make_tuple(kargs.n), + make_tuple(1), + number{}, + number<1>{}); + + const auto tmp2_ = + pad_tensor_view(tmp_, 
make_tuple(number{}), sequence{}); + + return make_tile_window(tmp2_, make_tuple(number{}), {0}); + }(); + + auto y_window = [&]() { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_y), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + auto inv_rms_window = [&]() { + if constexpr(kSaveInvRms) + { + const auto inv_rms_m = [&]() { + const auto inv_rms_dram_naive = + make_naive_tensor_view_packed( + static_cast(kargs.p_invRms), + make_tuple(kargs.m), + number<1>{}); + + return pad_tensor_view( + inv_rms_dram_naive, make_tuple(number{}), sequence{}); + }(); + return make_tile_window(inv_rms_m, make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + __shared__ char smem[GetSmemSize()]; + + Pipeline{}(x_window, + gamma_window, + y_window, + inv_rms_window, + static_cast(kargs.epsilon), + kargs.n, + smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp new file mode 100644 index 0000000000..fb484a1069 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +/* +// clang-format off + +4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector + + Block_N (Warp_N * WarpPerBlock_N * Repeat_N ) + +<----------------------< Repeat_N(2)>--------------------->+ + | | + +<-- -->+ + Warp_N + +--------------+--------------+--------------+--------------+----+----------------+ + Warp_M | warp_0 | warp_1 | | ^ ^ + +--------------+--------------+ | | + | warp_2 | warp_3 | | v + +--------------+--------------+--------------+--------------+----+ Block_M + | | | + + + | + | | | v + +--------------+--------------+--------------+--------------+ + + + each Warp-tile (e.g 16 thrd per row) + + Vector_N (contiguous pixels each thrd holds along N, or vector size) + +-----------+-----------+-----------+-----------+-----------+ + | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M + +-----------+-----------+-----------+-----------+-----------+ + | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ... + +-----------+-----------+-----------+-----------+-----------+ +// clang-format on +*/ +template + typename WarpPerBlock_, // num warps along seq + typename WarpTile_, // warp size, seq + typename Vector_, // contiguous pixels(vector size) along seq + index_t BlockSize_ = + warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> +struct Rmsnorm2dShape +{ + // block size + static constexpr index_t Block_M = BlockTile_::at(number<0>{}); + static constexpr index_t Block_N = BlockTile_::at(number<1>{}); + + // num warps along seq, within each block + static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); + + // warp size + static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); + + static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); + static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); + // repeat
of each thread along seq + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + // vector size along seq + static constexpr index_t Vector_M = Vector_::at(number<0>{}); + static constexpr index_t Vector_N = Vector_::at(number<1>{}); + + static_assert(Warp_M % Vector_M == 0); + static_assert(Warp_N % Vector_N == 0); + // num of threads along seq, within each warp + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t BlockSize = BlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..e4814cf455 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/reduce/block/block_reduce2d.hpp" + +namespace ck_tile { + +struct Rmsnorm2dFwdPipelineDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template + CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2d{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dCrossWarpSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + if constexpr(Problem::kNeedCrossWarpSync) + { + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = + decltype(make_static_distributed_tensor( + MakeXBlockTileDistribution())); + using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile()); + + return GetBlockReduce2dCrossWarpSync().template GetSmemSize(); + } + else + { + return 1; // zero size arrays are an extension + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp 
b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp new file mode 100644 index 0000000000..68cfe4282b --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -0,0 +1,101 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct Rmsnorm2dFwdPipelineOnePass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using InvRmsDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::kSaveInvRms; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_op"; // block per row + else + return "wpr_op"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const GammaWindow& gamma_window_, + YWindow& y_window, + InvRmsWindow& inv_rms_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + const auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + + auto reduce_square_sum_func = 
ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + const auto x = load_tile(x_window); + // load gamma (TODO: support no gamma?) + const auto gamma = load_tile(gamma_window); + + // compute mean square each-thread->cross-lane->cross-warp + auto square_sum = block_reduce2d( + x, reduce_square_sum_func.GetIdentityValue(), reduce_square_sum_func); + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + // compute inv-rms + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + if constexpr(kSaveInvRms) + store_tile(inv_rms_window, cast_tile(inv_rms)); + + // rmsnorm computation + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms_[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + store_tile(y_window, y); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp new file mode 100644 index 0000000000..87cab34631 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct Rmsnorm2dFwdPipelineProblem +{ + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using InvRmsDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp new file mode 100644 index 0000000000..a892df6bdb --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. 
+ +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct Rmsnorm2dFwdPipelineTwoPass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using InvRmsDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::kSaveInvRms; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr_tp"; // block per row + else + return "wpr_tp"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const GammaWindow& gamma_window_, + YWindow& y_window, + InvRmsWindow& inv_rms_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + + // Problem::BlockShape + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + + auto reduce_square_sum_func = ReduceOp::SquareAdd{}; + auto reduce_sum_func = ReduceOp::Add{}; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = 
Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + using XTensorType = decltype(load_tile(x_window)); + auto square_sum = block_reduce2d.template MakeYBlockTile(); + set_tile(square_sum, reduce_square_sum_func.GetIdentityValue()); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + block_reduce2d(x, square_sum, reduce_square_sum_func); + move_tile_window(x_window, {0, Block_N}); + } + + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + // compute inv-rms + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + if constexpr(kSaveInvRms) + store_tile(inv_rms_window, cast_tile(inv_rms)); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = + row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {stride_to_right_most_window}); + move_tile_window(y_window, {0, stride_to_right_most_window}); + + // rmsnorm computation + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + // load gamma (TODO: support no gamma?)
+ const auto gamma = load_tile(gamma_window); + + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms_[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + + store_tile(y_window, y); + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {-Block_N}); + move_tile_window(y_window, {0, -Block_N}); + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/welford/block/block_welford.hpp b/include/ck_tile/ops/welford/block/block_welford.hpp index 55d55402d8..623e1e16d8 100644 --- a/include/ck_tile/ops/welford/block/block_welford.hpp +++ b/include/ck_tile/ops/welford/block/block_welford.hpp @@ -276,8 +276,8 @@ struct BlockWelfordCrossWarpSync fp32x4_t all_scratch[thread_buf_size * num_reduce_warps]; static_for<0, thread_buf_size, 1>{}([&](auto i_0) { static_for<0, num_reduce_warps, 1>{}([&](auto i_1) { - all_scratch[i_0 * num_warps + i_1] = - smem_ptr[i_0 * num_reduce_warps + local_smem_os + i_1]; + all_scratch[i_0 * num_reduce_warps + i_1] = + smem_ptr[i_0 * num_warps + local_smem_os + i_1]; }); }); block_sync_lds(); // TODO: we don't need sync here @@ -286,7 +286,7 @@ struct BlockWelfordCrossWarpSync static_for<0, thread_buf_size, 1>{}([&](auto i_0) { // TODO: use descriptor for this - auto v_local = all_scratch[i_0 * num_warps]; + auto v_local = all_scratch[i_0 * num_reduce_warps]; auto v_local_mean = bit_cast(v_local[0]); auto v_local_var = bit_cast(v_local[1]); auto v_local_count = bit_cast(v_local[2]); @@ -294,7 +294,7 @@ struct BlockWelfordCrossWarpSync // further reduce mean/var static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) { constexpr auto i_1 = number{}; - const fp32x4_t v_remote = all_scratch[i_0 * num_warps + i_1]; 
+ const fp32x4_t v_remote = all_scratch[i_0 * num_reduce_warps + i_1]; const auto v_remote_mean = bit_cast(v_remote[0]); const auto v_remote_var = bit_cast(v_remote[1]); const auto v_remote_count = bit_cast(v_remote[2]);