From 382a2af2129eab35b16a7cc49e51c671fb0b3874 Mon Sep 17 00:00:00 2001 From: rocking Date: Tue, 22 Oct 2024 21:18:42 +0000 Subject: [PATCH] Add rmsnorm2d --- example/ck_tile/06_rmsnorm2d/CMakeLists.txt | 21 ++ example/ck_tile/06_rmsnorm2d/README.md | 22 ++ .../instances/rmsnorm2d_fwd_api.cpp | 153 +++++++++++++ .../rmsnorm2d_fwd_bf16_n1024_instance.cpp | 22 ++ .../rmsnorm2d_fwd_bf16_n1536_instance.cpp | 13 ++ .../rmsnorm2d_fwd_bf16_n2048_instance.cpp | 14 ++ .../rmsnorm2d_fwd_bf16_n256_instance.cpp | 12 ++ .../rmsnorm2d_fwd_bf16_n3072_instance.cpp | 14 ++ .../rmsnorm2d_fwd_bf16_n4096_instance.cpp | 14 ++ .../rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp | 14 ++ .../rmsnorm2d_fwd_bf16_n512_instance.cpp | 13 ++ .../rmsnorm2d_fwd_bf16_n64_n128_instance.cpp | 12 ++ .../rmsnorm2d_fwd_bf16_n768_instance.cpp | 12 ++ .../rmsnorm2d_fwd_fp16_n1024_instance.cpp | 22 ++ .../rmsnorm2d_fwd_fp16_n1536_instance.cpp | 13 ++ .../rmsnorm2d_fwd_fp16_n2048_instance.cpp | 14 ++ .../rmsnorm2d_fwd_fp16_n256_instance.cpp | 12 ++ .../rmsnorm2d_fwd_fp16_n3072_instance.cpp | 14 ++ .../rmsnorm2d_fwd_fp16_n4096_instance.cpp | 14 ++ .../rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp | 14 ++ .../rmsnorm2d_fwd_fp16_n512_instance.cpp | 13 ++ .../rmsnorm2d_fwd_fp16_n64_n128_instance.cpp | 12 ++ .../rmsnorm2d_fwd_fp16_n768_instance.cpp | 12 ++ .../rmsnorm2d_fwd_instance_common.hpp | 65 ++++++ .../ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.cpp | 179 +++++++++++++++ .../ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.hpp | 117 ++++++++++ example/ck_tile/CMakeLists.txt | 1 + include/ck_tile/host.hpp | 1 + .../reference/reference_rmsnorm2d_fwd.hpp | 52 +++++ include/ck_tile/ops/rmsnorm2d.hpp | 12 ++ .../rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp | 204 ++++++++++++++++++ .../rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp | 78 +++++++ .../rmsnorm2d_fwd_pipeline_default_policy.hpp | 94 ++++++++ .../rmsnorm2d_fwd_pipeline_one_pass.hpp | 100 +++++++++ .../rmsnorm2d_fwd_pipeline_problem.hpp | 36 ++++ .../rmsnorm2d_fwd_pipeline_two_pass.hpp | 131 +++++++++++ 36 files changed, 1546 insertions(+) create mode 100644 example/ck_tile/06_rmsnorm2d/CMakeLists.txt create mode 100644 example/ck_tile/06_rmsnorm2d/README.md create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp create mode 100644 example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.cpp create mode 100644 example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.hpp create mode 100644 include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp create mode 100644 include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp diff --git a/example/ck_tile/06_rmsnorm2d/CMakeLists.txt b/example/ck_tile/06_rmsnorm2d/CMakeLists.txt new file mode 100644 index 0000000000..db743f519e --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/CMakeLists.txt @@ -0,0 +1,21 @@ +set(EXAMPLE_RMSNORM2D_FWD "tile_example_rmsnorm2d_fwd") +# not using add_example_executable() to add this target, since we don't want this to have +# to be included in "make all/install/check" +message("adding example ${EXAMPLE_RMSNORM2D_FWD}") +file(GLOB INSTANCE_SRCS instances/*.cpp) +add_executable(${EXAMPLE_RMSNORM2D_FWD} EXCLUDE_FROM_ALL rmsnorm2d_fwd.cpp) +target_include_directories(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR}) +target_sources(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${INSTANCE_SRCS}) + +set(EXAMPLE_RMSNORM2D_FWD_COMPILE_OPTIONS) + +# NOTE: we turn off undefined-func-template to let source compile without explicit declare function specializations +list(APPEND EXAMPLE_RMSNORM2D_FWD_COMPILE_OPTIONS -Wno-undefined-func-template -Wno-float-equal) + +target_compile_options(${EXAMPLE_RMSNORM2D_FWD} PRIVATE ${EXAMPLE_RMSNORM2D_FWD_COMPILE_OPTIONS}) + +# TODO: we have to turn off this global prop, otherwise the progress bar generated +# by cmake will print too many files, execvp: /bin/sh: Argument list too long +# however, this property may affect global +# TODO: consider codegen a makefile by us +set_property(GLOBAL PROPERTY RULE_MESSAGES OFF) diff --git a/example/ck_tile/06_rmsnorm2d/README.md b/example/ck_tile/06_rmsnorm2d/README.md new file mode 100644 index 0000000000..c760a4f075 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/README.md @@ -0,0 +1,22 @@ +# Rmsnorm2D forward + +This folder contains example for Rmsnorm2D forward using ck_tile tile-programming implementation. + +## build +``` +# in the root of ck_tile +mkdir build && cd build +sh ../script/cmake-ck-dev.sh ../ # you can replace this to gfx90a, gfx942... +make tile_example_rmsnorm2d_fwd -j +``` +This will result in an executable `build/bin/tile_example_rmsnorm2d_fwd` + +## example +``` +args: + -m m dimension (default:3328) + -n m dimension (default:4096) + -e epsilon (default:1e-5) + -v cpu validation or not (default:1) + -prec precision (default:fp16) +``` diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp new file mode 100644 index 0000000000..f9cfe72ded --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_api.cpp @@ -0,0 +1,153 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "rmsnorm2d_fwd.hpp" + +template +using trait_ = rmsnorm2d_fwd_traits_; + +template +float rmsnorm2d_fwd_b16_(rmsnorm2d_fwd_traits /*t*/, + rmsnorm2d_fwd_args a, + const ck_tile::stream_config& s) +{ +#if 1 + float r = -1; + // clang-format off + // rm rn tm tn vn pd rms 2p + if(a.n <= 64) { + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 128) { + if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 256) { + if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 512) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 768) { + if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 1024) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 1536) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 2048) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 3072) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n <= 4096) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + else if(a.n > 4096) { + if (a.n % 8 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 4 == 0) + r = rmsnorm2d_fwd_>(s, a); + else if (a.n % 2 == 0) + r = rmsnorm2d_fwd_>(s, a); + else + r = rmsnorm2d_fwd_>(s, a); + } + return r; +#else + return rmsnorm2d_fwd_>(s, a); +#endif + // clang-format on +} + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, rmsnorm2d_fwd_args a, const ck_tile::stream_config& s) +{ + + float r = -1; + if(t.data_type.compare("fp16") == 0) + { + return rmsnorm2d_fwd_b16_(t, a, s); + } + else if(t.data_type.compare("bf16") == 0) + { + return rmsnorm2d_fwd_b16_(t, a, s); + } + if(r < 0) + throw std::runtime_error("Without supported instances!"); + + return r; +} diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp new file mode 100644 index 0000000000..5e2a35f9e8 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +#if 0 +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +template float rmsnorm2d_fwd_>(const S&, A); +#endif + +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp new file mode 100644 index 0000000000..8c734806e1 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp new file mode 100644 index 0000000000..9222001433 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp new file mode 100644 index 0000000000..ed33c84923 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp new file mode 100644 index 0000000000..b753bbc345 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp new file mode 100644 index 0000000000..27cb9bdf3d --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..23afb5672b --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp new file mode 100644 index 0000000000..b428f58051 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp new file mode 100644 index 0000000000..3001106697 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp new file mode 100644 index 0000000000..e9c8d6a1d4 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_bf16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp new file mode 100644 index 0000000000..15198eebe6 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1024_instance.cpp @@ -0,0 +1,22 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +#if 0 +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +template float rmsnorm2d_fwd_>(const S&, A); +#endif + +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp new file mode 100644 index 0000000000..8ac85fa9b5 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n1536_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp new file mode 100644 index 0000000000..10e8fafc2f --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n2048_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp new file mode 100644 index 0000000000..4e1a80bf64 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n256_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp new file mode 100644 index 0000000000..45e56a92b8 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n3072_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp new file mode 100644 index 0000000000..35401f6f82 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp new file mode 100644 index 0000000000..1e3700fad3 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n4096_tp_instance.cpp @@ -0,0 +1,14 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); + +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp new file mode 100644 index 0000000000..cdc4d00bd2 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n512_instance.cpp @@ -0,0 +1,13 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp new file mode 100644 index 0000000000..ec80c2ee4a --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n64_n128_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp new file mode 100644 index 0000000000..ddfc5a54e8 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_fp16_n768_instance.cpp @@ -0,0 +1,12 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "rmsnorm2d_fwd_instance_common.hpp" + +// clang-format off +// rm rn tm tn vn pd rms 2p +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +template float rmsnorm2d_fwd_>(const S&, A); +// clang-format on diff --git a/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp new file mode 100644 index 0000000000..7619ca3dc9 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/instances/rmsnorm2d_fwd_instance_common.hpp @@ -0,0 +1,65 @@ + +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#include +#include "rmsnorm2d_fwd.hpp" +#include + +#pragma once + +using S = ck_tile::stream_config; +using A = rmsnorm2d_fwd_args; + +template +using trait_ = rmsnorm2d_fwd_traits_; + +template +float rmsnorm2d_fwd_(const S& s, A a) +{ + using DataType = typename Traits_::DataType; + + using PipelineProblem = + ck_tile::Rmsnorm2dFwdPipelineProblem::XDataType, + typename RmsNormTypeConfig::GammaDataType, + typename RmsNormTypeConfig::ComputeDataType, + typename RmsNormTypeConfig::YDataType, + typename RmsNormTypeConfig::InvRmsDataType, + typename Traits_::Shape, + Traits_::kPadN, + Traits_::kSaveInvRms, + Traits_::kTwoPass>; + + using OnePassPipeline = ck_tile::Rmsnorm2dFwdPipelineOnePass; + using TwoPassPipeline = ck_tile::Rmsnorm2dFwdPipelineTwoPass; + using Pipeline = std::conditional_t; + + using Kernel = ck_tile::Rmsnorm2dFwd; + + const dim3 grids = Kernel::GridSize(a); + constexpr dim3 blocks = Kernel::BlockSize(); + constexpr ck_tile::index_t kBlockPerCu = 1; + + auto kargs = Kernel::MakeKargs(a); + if(s.log_level_ > 0) + std::cout << ", " << Kernel::GetName() << std::flush; + + return ck_tile::launch_kernel( + s, ck_tile::make_kernel(Kernel{}, grids, blocks, 0, kargs)); +} diff --git a/example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.cpp new file mode 100644 index 0000000000..453d9a575d --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -0,0 +1,179 @@ +#include "ck_tile/host.hpp" +#include "rmsnorm2d_fwd.hpp" +#include + +// different threshold for different dtype +template +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +template <> +auto get_elimit() +{ + double rtol = 1e-2; + double atol = 1e-2; + return ck_tile::make_tuple(rtol, atol); +} + +auto create_args(int argc, char* argv[]) +{ + ck_tile::ArgParser arg_parser; + arg_parser.insert("m", "3328", "m dimension") + .insert("n", "4096", "n dimension") + .insert("stride", "-1", "stride per row, if -1 then equal to n") + .insert("e", "1e-5", "epsilon") + .insert("save_rms", "0", "save rms(invrms) or not. set to 1 in training case") + .insert("v", "1", "cpu validation or not") + .insert("kname", "1", "print kernel name or not") + .insert("prec", "fp16", "precision") + .insert("warmup", "5", "cold iter") + .insert("repeat", "20", "hot iter"); + + bool result = arg_parser.parse(argc, argv); + return std::make_tuple(result, arg_parser); +} + +template +bool run(const ck_tile::ArgParser& arg_parser) +{ + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + ck_tile::index_t stride = arg_parser.get_int("stride"); + if(stride < 0) + stride = n; + float epsilon = arg_parser.get_float("e"); + std::string data_type = arg_parser.get_str("prec"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + + assert(stride >= n); + + using TypeConfig = RmsNormTypeConfig; + + using XDataType = typename TypeConfig::XDataType; + using YDataType = typename TypeConfig::YDataType; + using GammaDataType = typename TypeConfig::GammaDataType; + + using InvRmsDataType = + std::conditional_t; + + using ComputeDataType = typename TypeConfig::ComputeDataType; + + // host verify + ck_tile::HostTensor x_host({m, n}, {stride, 1}); + ck_tile::HostTensor gamma_host({n}); + + ck_tile::HostTensor y_host_ref({m, n}, {stride, 1}); + ck_tile::HostTensor y_host_dev({m, n}, {stride, 1}); + + ck_tile::HostTensor invRms_host_ref({m}); + + ck_tile::FillUniformDistribution{-.5f, .5f}(x_host); + ck_tile::FillUniformDistribution{-.5f, .5f}(gamma_host); + + ck_tile::DeviceMem x_buf(x_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem gamma_buf(gamma_host.get_element_space_size_in_bytes()); + ck_tile::DeviceMem y_buf(y_host_dev.get_element_space_size_in_bytes()); + + x_buf.ToDevice(x_host.data()); + gamma_buf.ToDevice(gamma_host.data()); + + std::cout << "[" << data_type << "]" + << " m:" << m << ", n:" << n << ", stride:" << stride << std::flush; + + rmsnorm2d_fwd_traits traits{data_type, SaveRms}; + + rmsnorm2d_fwd_args args{x_buf.GetDeviceBuffer(), + gamma_buf.GetDeviceBuffer(), + y_buf.GetDeviceBuffer(), + nullptr, + epsilon, + m, + n, + stride}; + + float ave_time = rmsnorm2d_fwd( + traits, args, ck_tile::stream_config{nullptr, true, kname ? 1 : 0, warmup, repeat}); + + std::size_t num_byte = + sizeof(XDataType) * m * n + sizeof(GammaDataType) * n + sizeof(YDataType) * m * n; + + float gb_per_sec = num_byte / 1.E6 / ave_time; + std::cout << ", " << ave_time * 1.E3 << " us, " << gb_per_sec << " GB/s" << std::flush; + + bool pass = true; + + if(do_validation) + { + // reference + ck_tile::reference_rmsnorm2d_fwd( + x_host, gamma_host, y_host_ref, invRms_host_ref, epsilon); + + y_buf.FromDevice(y_host_dev.data()); + + auto [rtol, atol] = get_elimit(); + if(stride == n) + { + pass = ck_tile::check_err( + y_host_dev, y_host_ref, std::string("OUT Error: Incorrect results!"), rtol, atol); + } + else + { + for(int i_r = 0; i_r < m; i_r++) + { + std::vector y_host_dev_row(y_host_dev.begin() + i_r * stride, + y_host_dev.begin() + i_r * stride + n); + std::vector y_host_ref_row(y_host_ref.begin() + i_r * stride, + y_host_ref.begin() + i_r * stride + n); + pass &= ck_tile::check_err(y_host_dev_row, + y_host_ref_row, + std::string("OUT[") + std::to_string(i_r) + + std::string("] Error: Incorrect results!"), + rtol, + atol); + } + } + + std::cout << ", valid:" << (pass ? "y" : "n") << std::flush << std::endl; + } + + return pass; +} + +int main(int argc, char* argv[]) +{ + auto [result, arg_parser] = create_args(argc, argv); + if(!result) + return -1; + + const std::string data_type = arg_parser.get_str("prec"); + int save_rms = arg_parser.get_int("save_rms"); + if(data_type == "fp16" && save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "fp16" && !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && save_rms) + { + return run(arg_parser) ? 0 : -2; + } + else if(data_type == "bf16" && !save_rms) + { + return run(arg_parser) ? 0 : -2; + } + + return -3; +} diff --git a/example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.hpp b/example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.hpp new file mode 100644 index 0000000000..d557a7d4d8 --- /dev/null +++ b/example/ck_tile/06_rmsnorm2d/rmsnorm2d_fwd.hpp @@ -0,0 +1,117 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/kernel_launch.hpp" +#include "ck_tile/ops/rmsnorm2d.hpp" +#include + +template +struct RmsNormTypeConfig; + +template <> +struct RmsNormTypeConfig +{ + using XDataType = ck_tile::half_t; + using YDataType = ck_tile::half_t; + using GammaDataType = ck_tile::half_t; + using InvRmsDataType = ck_tile::half_t; + using ComputeDataType = float; +}; + +template <> +struct RmsNormTypeConfig +{ + using XDataType = ck_tile::bf16_t; + using YDataType = ck_tile::bf16_t; + using GammaDataType = ck_tile::bf16_t; + using InvRmsDataType = ck_tile::bf16_t; + using ComputeDataType = float; +}; + +// runtime args +struct rmsnorm2d_fwd_args : public ck_tile::Rmsnorm2dFwdHostArgs +{ +}; + +// this is used to pattern-match internl kernel implementation, not to instantiate kernel +template +struct rmsnorm2d_fwd_traits_ +{ + using DataType = ck_tile::remove_cvref_t; + + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= warpSize; + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % warpSize == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / warpSize; + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return total_warps * (warpSize / ThreadPerBlock_N_); + } + else + { + // static_assert(warpSize % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / warpSize); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(warpSize % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % warpSize == 0); + return ThreadPerBlock_N_ / warpSize; + } + }(); + + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; + static constexpr ck_tile::index_t Repeat_N = Repeat_N_; + + static constexpr ck_tile::index_t Block_M = Repeat_M_ * ThreadPerBlock_M_; + static constexpr ck_tile::index_t Block_N = Repeat_N_ * ThreadPerBlock_N_ * Vector_N_; + + static constexpr ck_tile::index_t Warp_M = ThreadPerBlock_M_ / BlockWarps_M; + static constexpr ck_tile::index_t Warp_N = ThreadPerBlock_N_ / BlockWarps_N * Vector_N_; + + using BlockTile = ck_tile::sequence; + using BlockWarps = ck_tile::sequence; + using WarpTile = ck_tile::sequence; + using Vector = ck_tile::sequence<1, Vector_N_>; + + using Shape = ck_tile::Rmsnorm2dShape; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +template +float rmsnorm2d_fwd_(const ck_tile::stream_config& s, rmsnorm2d_fwd_args a); + +// This is the public API, will be generated by script +struct rmsnorm2d_fwd_traits +{ + std::string data_type; + bool save_mean_var; +}; + +float rmsnorm2d_fwd(rmsnorm2d_fwd_traits, rmsnorm2d_fwd_args, const ck_tile::stream_config&); diff --git a/example/ck_tile/CMakeLists.txt b/example/ck_tile/CMakeLists.txt index ec4a175d35..22d51d1e10 100644 --- a/example/ck_tile/CMakeLists.txt +++ b/example/ck_tile/CMakeLists.txt @@ -7,3 +7,4 @@ add_subdirectory(02_layernorm2d) add_subdirectory(03_gemm) add_subdirectory(04_img2col) add_subdirectory(05_reduce) +add_subdirectory(06_rmsnorm2d) diff --git a/include/ck_tile/host.hpp b/include/ck_tile/host.hpp index dbc1f5d23a..8108b0f1ae 100644 --- a/include/ck_tile/host.hpp +++ b/include/ck_tile/host.hpp @@ -23,6 +23,7 @@ #include "ck_tile/host/reference/reference_im2col.hpp" #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_reduce.hpp" +#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp" #include "ck_tile/host/reference/reference_softmax.hpp" #include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/timer.hpp" diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp new file mode 100644 index 0000000000..a3c392e2c5 --- /dev/null +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/host/host_tensor.hpp" + +namespace ck_tile { + +template +void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, + const HostTensor& gamma_n, + HostTensor& y_m_n, + HostTensor& invRms_m, + ComputeDataType epsilon) +{ + auto rmsnorm2d_fwd_func = [&](auto m) { + const int N = x_m_n.mDesc.get_lengths()[1]; + + ComputeDataType mean_square = 0; + ComputeDataType divisor = 0; + + for(int n = 0; n < N; ++n) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + mean_square += x * x; + } + + mean_square = mean_square / N; + divisor = ck_tile::type_convert(1) / ck_tile::sqrt(mean_square + epsilon); + + if constexpr(!std::is_same_v) + invRms_m(m) = ck_tile::type_convert(divisor); + + for(int n = 0; n < N; ++n) + { + ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); + ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); + auto y = x * divisor * gamma; + y_m_n(m, n) = ck_tile::type_convert(y); + } + }; + + make_ParallelTensorFunctor(rmsnorm2d_fwd_func, + invRms_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency()); +} +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d.hpp b/include/ck_tile/ops/rmsnorm2d.hpp new file mode 100644 index 0000000000..98c60f1b51 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d.hpp @@ -0,0 +1,12 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp" +#include "ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp" +#include "ck_tile/ops/common/tensor_layout.hpp" diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp new file mode 100644 index 0000000000..6207c24f43 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp @@ -0,0 +1,204 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/common.hpp" + +namespace ck_tile { + +// host side args +struct Rmsnorm2dFwdHostArgs +{ + const void* p_x; + const void* p_gamma; + + void* p_y; + void* p_invRms; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride +}; + +// TODO: Extract some type to wrapper class +template +struct Rmsnorm2dFwd +{ + using Pipeline = remove_cvref_t; + using Problem = typename Pipeline::Problem; + + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using InvRmsDataType = remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::kSaveInvRms; + + static constexpr index_t Block_M = Problem::BlockShape::Block_M; + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + static constexpr bool kPadM = false; // always no need to pad along M + static constexpr bool kPadN = Problem::kPadN; + static constexpr bool kTwoPass = Problem::kTwoPass; + + static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N; + static constexpr index_t Vector_N = Problem::BlockShape::Vector_N; + static constexpr index_t Repeat_N = Problem::BlockShape::Repeat_N; + + static constexpr auto I0 = number<0>{}; + static constexpr auto I1 = number<1>{}; + + struct Kargs + { + const void* p_x; + const void* p_gamma; + + void* p_y; + void* p_invRms; + + float epsilon; + + index_t m; + index_t n; + index_t stride; // row_stride + }; + using Hargs = Rmsnorm2dFwdHostArgs; + + CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs) + { + return Kargs{hargs.p_x, + hargs.p_gamma, + hargs.p_y, + hargs.p_invRms, + hargs.epsilon, + hargs.m, + hargs.n, + hargs.stride}; + } + + CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs) + { + return (hargs.m + Block_M - 1) / Block_M; + } + + CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; } + + // clang-format off + template struct t2s; + template <> struct t2s { static constexpr const char * name = "fp32"; }; + template <> struct t2s { static constexpr const char * name = "fp16"; }; + template <> struct t2s { static constexpr const char * name = "bf16"; }; + template <> struct t2s { static constexpr const char * name = "fp8"; }; + template <> struct t2s { static constexpr const char * name = "bf8"; }; + // clang-format on + + // in byte + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); } + + CK_TILE_HOST static std::string GetName() + { + // clang-format off + using S_ = typename Problem::BlockShape; + auto surfix = [&] () { + std::string n; + if (kPadN) n += "_pn"; + if (kSaveInvRms) n += "_rms"; + if (kTwoPass) n += "_2p"; + return n; }(); + + #define _SS_ std::string + #define _TS_ std::to_string + return _SS_("rmsnorm2d_fwd_") + _SS_(t2s::name) + "_" + + _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" + _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" + + _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" + _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" + + _SS_(Pipeline::name) + surfix; + #undef _SS_ + #undef _TS_ + // clang-format on + } + + CK_TILE_DEVICE void operator()(Kargs kargs) const + { + const auto iM = get_block_id() * Block_M; + + const auto x_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_x), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + // NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will + // check the max count dynamically + const auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + const auto gamma_window = [&]() { + const auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_gamma), + make_tuple(kargs.n), + make_tuple(1), + number{}, + number<1>{}); + + const auto tmp2_ = + pad_tensor_view(tmp_, make_tuple(number{}), sequence{}); + + return make_tile_window(tmp2_, make_tuple(number{}), {0}); + }(); + + auto y_window = [&]() { + auto tmp_ = make_naive_tensor_view( + static_cast(kargs.p_y), + make_tuple(kargs.m, kargs.n), + make_tuple(kargs.stride, 1), + number{}, + number<1>{}); + + auto tmp2_ = pad_tensor_view( + tmp_, make_tuple(number{}, number{}), sequence{}); + return make_tile_window( + tmp2_, make_tuple(number{}, number{}), {iM, 0}); + }(); + + auto inv_rms_window = [&]() { + if constexpr(kSaveInvRms) + { + const auto inv_rms_m = [&]() { + const auto inv_rms_dram_naive = + make_naive_tensor_view_packed( + static_cast(kargs.p_invRms), + make_tuple(kargs.m), + number<1>{}); + + return pad_tensor_view( + inv_rms_dram_naive, make_tuple(number{}), sequence{}); + }(); + return make_tile_window(inv_rms_m, make_tuple(number{}), {iM}); + } + else + return make_null_tile_window(make_tuple(number{})); + }(); + + __shared__ char smem[GetSmemSize()]; + + Pipeline{}(x_window, + gamma_window, + y_window, + inv_rms_window, + static_cast(kargs.epsilon), + kargs.n, + smem); + } +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp new file mode 100644 index 0000000000..fb484a1069 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_shape.hpp @@ -0,0 +1,78 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" + +namespace ck_tile { +/* +// clang-format off + +4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector + + Block_N (Warp_N * WarpPerBlock_N * Repeat_N ) + +<----------------------< Repeat_N(2)>--------------------->+ + | | + +<-- -->+ + Warp_N + +--------------+--------------+--------------+--------------+----+----------------+ + Warp_M | wrap_0 | wrap_1 | | ^ ^ + +--------------+--------------+ | | + | wrap_2 | wrap_3 | | v + +--------------+--------------+--------------+--------------+----+ Block_M + | | | + + + | + | | | v + +--------------+--------------+--------------+--------------+ + + + each Warp-tile (e.g 16 thrd per row) + + Vector_N (contiguous pixels each thrd holds along N, or vector size) + +-----------+-----------+-----------+-----------+-----------+ + | thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M + +-----------+-----------+-----------+-----------+-----------+ + | thrd_16 | thrd_17 | thrd_18 | thrd_19 | ... + +-----------+-----------+-----------+-----------+-----------+ +// clang-format on +*/ +template + typename WarpPerBlock_, // num warps along seq + typename WarpTile_, // warp size, seq + typename Vector_, // contiguous pixels(vector size) along seq + index_t BlockSize_ = + warpSize* reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})> +struct Rmsnorm2dShape +{ + // block size + static constexpr index_t Block_M = BlockTile_::at(number<0>{}); + static constexpr index_t Block_N = BlockTile_::at(number<1>{}); + + // num warps along seq, within each block + static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{}); + static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{}); + + // warp size + static constexpr index_t Warp_M = WarpTile_::at(number<0>{}); + static constexpr index_t Warp_N = WarpTile_::at(number<1>{}); + + static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0); + static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0); + // repeat of each thread along seq + static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M); + static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N); + + // vector size along seq + static constexpr index_t Vector_M = Vector_::at(number<0>{}); + static constexpr index_t Vector_N = Vector_::at(number<1>{}); + + static_assert(Warp_M % Vector_M == 0); + static_assert(Warp_N % Vector_N == 0); + // num of threads along seq, within each warp + static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M; + static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N; + + static constexpr index_t BlockSize = BlockSize_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp new file mode 100644 index 0000000000..421e0d0c7e --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp @@ -0,0 +1,94 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/reduce2d/block/block_reduce2d_problem.hpp" +#include "ck_tile/ops/reduce2d/block/block_reduce2d.hpp" + +namespace ck_tile { + +struct Rmsnorm2dFwdPipelineDefaultPolicy +{ + template + CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence<>, + tuple, + sequence>, + tuple, sequence<1, 2>>, + tuple, sequence<2, 2>>, + sequence<1, 1, 2, 2>, + sequence<0, 3, 0, 3>>{}); + } + template + CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution() + { + using S = typename Problem::BlockShape; + + return make_static_tile_distribution( + tile_distribution_encoding< + sequence, + tuple>, + tuple, sequence<0, 1>>, + tuple, sequence<1, 2>>, + sequence<1, 1>, + sequence<0, 3>>{}); + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2d{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync() + { + using P_ = BlockReduce2dProblem; + return BlockReduce2dCrossWarpSync{}; + } + + template + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + if constexpr(Problem::kNeedCrossWarpSync) + { + using P_ = BlockReduce2dProblem; + + using block_reduce2d = BlockReduce2d; + using x_block_tile = + decltype(make_static_distributed_tensor( + MakeXBlockTileDistribution())); + using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile()); + + return GetBlockReduce2dCrossWarpSync().template GetSmemSize(); + } + else + { + return 1; // zero size arrays are an extension + } + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp new file mode 100644 index 0000000000..540ac1750c --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp @@ -0,0 +1,100 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct Rmsnorm2dFwdPipelineOnePass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using InvRmsDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::kSaveInvRms; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr"; // block per row + else + return "wpr"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const GammaWindow& gamma_window_, + YWindow& y_window, + InvRmsWindow& inv_rms_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + const auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + const auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + + auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; + auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + const auto x = load_tile(x_window); + // load gamma (TODO: support no gamma?) + const auto gamma = load_tile(gamma_window); + + // compute mean square each-thread->cross-lane->cross-warp + auto square_sum = block_reduce2d(x, 0, reduce_square_sum_func); + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + // compute inv-rms + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + if constexpr(kSaveInvRms) + store_tile(inv_rms_window, cast_tile(inv_rms)); + + // rmsnorm computation + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms_[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + store_tile(y_window, y); + } +}; +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp new file mode 100644 index 0000000000..87cab34631 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp @@ -0,0 +1,36 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core/utility/type_traits.hpp" + +namespace ck_tile { + +template +struct Rmsnorm2dFwdPipelineProblem +{ + using XDataType = remove_cvref_t; + using GammaDataType = remove_cvref_t; + using ComputeDataType = remove_cvref_t; + using YDataType = remove_cvref_t; + using InvRmsDataType = remove_cvref_t; + using BlockShape = remove_cvref_t; + + static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1; + static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1; + + static constexpr bool kPadN = kPadN_; + static constexpr bool kSaveInvRms = kSaveInvRms_; + static constexpr bool kTwoPass = kTwoPass_; +}; + +} // namespace ck_tile diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp new file mode 100644 index 0000000000..1fa8ed9738 --- /dev/null +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp @@ -0,0 +1,131 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. + +#pragma once + +#include "ck_tile/core.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp" +#include +#include + +namespace ck_tile { + +template +struct Rmsnorm2dFwdPipelineTwoPass +{ + using Problem = ck_tile::remove_cvref_t; + using Policy = ck_tile::remove_cvref_t; + + using XDataType = ck_tile::remove_cvref_t; + using GammaDataType = ck_tile::remove_cvref_t; + using ComputeDataType = ck_tile::remove_cvref_t; + using YDataType = ck_tile::remove_cvref_t; + using InvRmsDataType = ck_tile::remove_cvref_t; + + static constexpr bool kHasGamma = !std::is_same_v; + static constexpr bool kSaveInvRms = Problem::kSaveInvRms; + + static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync; + static constexpr bool kPadM = false; // TODO - BlockRmsnorm2dFwdProblem::kPadM + static constexpr bool kPadN = Problem::kPadN; + + static constexpr const char* name = []() { + if constexpr(kNeedCrossWarpSync) + return "bpr"; // block per row + else + return "wpr"; // warp per row + }(); + + CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() + { + return Policy::template GetSmemSize(); + } + + template + CK_TILE_DEVICE auto operator()(const XWindow& x_window_, + const GammaWindow& gamma_window_, + YWindow& y_window, + InvRmsWindow& inv_rms_window, + ComputeDataType epsilon, + ck_tile::index_t row_size, + void* smem) const + { + auto x_window = + make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution()); + auto gamma_window = make_tile_window( + gamma_window_, Policy::template MakeGammaBlockTileDistribution()); + + // Problem::BlockShape + static constexpr index_t Block_N = Problem::BlockShape::Block_N; + index_t num_n_tile_iteration = + __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N)); + + auto reduce_square_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1 * v1; }; + auto reduce_sum_func = [](const auto& v0, const auto& v1) { return v0 + v1; }; + auto block_reduce2d = Policy::template GetBlockReduce2d(); + auto block_reduce2d_sync = Policy::template GetBlockReduce2dSync(); + auto block_reduce2d_cross_warp_sync = + Policy::template GetBlockReduce2dCrossWarpSync(); + + using XTensorType = decltype(load_tile(x_window)); + auto square_sum = block_reduce2d.template MakeYBlockTile(); + set_tile(square_sum, 0); + + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + block_reduce2d(x, square_sum, reduce_square_sum_func); + move_tile_window(x_window, {0, Block_N}); + } + + block_reduce2d_sync(square_sum, reduce_sum_func); + block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func); + + // compute inv-rms + auto inv_rms = tile_elementwise_in( + [&](const auto& v_) { + return type_convert(1.0f) / (sqrt(v_ / row_size + epsilon)); + }, + square_sum); + + if constexpr(kSaveInvRms) + store_tile(inv_rms_window, cast_tile(inv_rms)); + + // reverse read x to reuse cache + ck_tile::index_t stride_to_right_most_window = + row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N; + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {stride_to_right_most_window}); + move_tile_window(y_window, {0, stride_to_right_most_window}); + + // rmsnorm computation + for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN) + { + const auto x = load_tile(x_window); + // load gamma/beta (TODO: support no gamma/beta?) + const auto gamma = load_tile(gamma_window); + + auto y = make_static_distributed_tensor(x.get_tile_distribution()); + + sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) { + constexpr auto i_idx = make_tuple(idx[number<0>{}]); + constexpr auto j_idx = make_tuple(idx[number<1>{}]); + + const auto gamma_ = type_convert(gamma[j_idx]); + + const auto x_ = type_convert(x[idx]); + auto y_ = x_ * inv_rms_[i_idx] * gamma_; + + y(idx) = type_convert(y_); + }); + + store_tile(y_window, y); + + move_tile_window(x_window, {0, -Block_N}); + move_tile_window(gamma_window, {-Block_N}); + move_tile_window(y_window, {0, -Block_N}); + } + } +}; +} // namespace ck_tile