From 6a423df526f839dd13a1fcec36e08c1394afb550 Mon Sep 17 00:00:00 2001 From: ClementLinCF <162283536+ClementLinCF@users.noreply.github.com> Date: Tue, 14 Oct 2025 02:52:37 +0800 Subject: [PATCH] [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540) * [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang Co-authored-by: illsilin_amdeng Co-authored-by: ThomasNing [ROCm/composable_kernel commit: e1b0bdfbfa92f47006fdbced627c7470eacdea2b] --- example/ck_tile/02_layernorm2d/generate.py | 33 +++++ example/ck_tile/10_rmsnorm2d/generate.py | 39 +++++- .../ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp | 47 +++++-- .../ck_tile/10_rmsnorm2d/script/smoke_test.sh | 124 +++++++++++------- .../reference/reference_rmsnorm2d_fwd.hpp | 31 ++++- .../ops/reduce/block/block_reduce2d.hpp | 4 +- ...rm2d_fwd_pipeline_model_sensitive_pass.hpp | 6 +- 7 files changed, 217 insertions(+), 67 deletions(-) diff --git a/example/ck_tile/02_layernorm2d/generate.py b/example/ck_tile/02_layernorm2d/generate.py index b7512b2999..5f589db8d0 100644 --- a/example/ck_tile/02_layernorm2d/generate.py +++ b/example/ck_tile/02_layernorm2d/generate.py @@ -75,6 +75,39 @@ struct layernorm2d_fwd_traits_ using SmoothScaleDataType = ck_tile::remove_cvref_t; using YScaleDataType = ck_tile::remove_cvref_t; + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps; + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; diff --git a/example/ck_tile/10_rmsnorm2d/generate.py b/example/ck_tile/10_rmsnorm2d/generate.py index 0e948322a2..75d7abd0ad 100644 --- a/example/ck_tile/10_rmsnorm2d/generate.py +++ b/example/ck_tile/10_rmsnorm2d/generate.py @@ -75,6 +75,39 @@ struct rmsnorm2d_fwd_traits_ using YScaleDataType = ck_tile::remove_cvref_t; using UnquantYDataType = ck_tile::remove_cvref_t; + static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size(); + static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0); + static constexpr ck_tile::index_t total_warps = + (ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size(); + + // num of warps along m + static constexpr ck_tile::index_t BlockWarps_M = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return total_warps; + } + else + { + // static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0); + return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size()); + } + }(); + + // num of warps along n + static constexpr ck_tile::index_t BlockWarps_N = []() { + if constexpr(is_warp_per_row) + { + static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0); + return 1; + } + else + { + static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0); + return ThreadPerBlock_N_ / ck_tile::get_warp_size(); + } + }(); + static constexpr ck_tile::index_t Repeat_M = Repeat_M_; static constexpr ck_tile::index_t Repeat_N = Repeat_N_; @@ -605,15 +638,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t, h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1), h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)] - } + } } - + total_blob = list() for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive current_trait_dict = h_trait_dicts[model_sensitive_flag] for hs_key in current_trait_dict: - hs = current_trait_dict[hs_key] + hs = current_trait_dict[hs_key] current_n = hs_key for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list): prec_i, prec_o = dtype.split(',') diff --git a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp index 6e2664e9ba..8518b5ddc7 100644 --- a/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp +++ b/example/ck_tile/10_rmsnorm2d/rmsnorm2d_fwd.cpp @@ -70,16 +70,16 @@ template bool run(const ck_tile::ArgParser& arg_parser) { - ck_tile::index_t m = arg_parser.get_int("m"); - ck_tile::index_t n = arg_parser.get_int("n"); - float epsilon = arg_parser.get_float("e"); - int kname = arg_parser.get_int("kname"); - int do_validation = arg_parser.get_int("v"); - int fused_add = arg_parser.get_int("fadd"); - int fused_quant = arg_parser.get_int("fquant"); - int warmup = arg_parser.get_int("warmup"); - int repeat = arg_parser.get_int("repeat"); - const int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); + ck_tile::index_t m = arg_parser.get_int("m"); + ck_tile::index_t n = arg_parser.get_int("n"); + float epsilon = arg_parser.get_float("e"); + int kname = arg_parser.get_int("kname"); + int do_validation = arg_parser.get_int("v"); + int fused_add = arg_parser.get_int("fadd"); + int fused_quant = arg_parser.get_int("fquant"); + int warmup = arg_parser.get_int("warmup"); + int repeat = arg_parser.get_int("repeat"); + int use_model_sensitive_rmsnorm = arg_parser.get_int("s"); ck_tile::index_t x_stride = arg_parser.get_int("x_stride"); if(x_stride < 0) @@ -196,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser) return base_str; }(); + if(n > 8192) + { + use_model_sensitive_rmsnorm = 0; + } + std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride << ", xr_stride:" << xr_stride << ", y_stride:" << y_stride << ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush; @@ -297,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser) const int N = acc_.mDesc.get_lengths()[1]; for(int n_ = 0; n_ < N; ++n_) { - o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); + o_unquant_(m_, n_) = ck_tile::type_convert(acc_(m_, n_)); } dquant_functor(m_, o_, acc_); @@ -316,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - default_and_dquant_functor); + default_and_dquant_functor, + use_model_sensitive_rmsnorm); } else { @@ -331,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser) invRms_host_ref, unquant_y_host_ref, epsilon, - dquant_functor); + dquant_functor, + use_model_sensitive_rmsnorm); } } else @@ -343,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser) YDataType, InvRmsDataType, ck_tile::null_type>( - x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon); + x_host, + gamma_host, + y_host_ref, + invRms_host_ref, + unquant_y_null, + epsilon, + ck_tile::reference_rmsnorm2d_default_epilogue{}, + use_model_sensitive_rmsnorm); } y_buf.FromDevice(y_host_dev.data()); @@ -354,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser) y_residual_buf.FromDevice(y_residual_host_dev.data()); } + if constexpr(SaveUnquant) + { + unquant_y_buf.FromDevice(unquant_y_host_dev.data()); + } + auto [rtol, atol] = get_elimit(); if(x_stride == n) { diff --git a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh index 1c79dafadd..3a0f7dbb66 100755 --- a/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh +++ b/example/ck_tile/10_rmsnorm2d/script/smoke_test.sh @@ -1,49 +1,85 @@ -#!/bin/sh +#!/bin/bash + EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)" -for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\ - "-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024 -# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096 -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192 -done -done -done +total=0 +valid=0 + +run_case() { + cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7" + echo "[CMD] $cmd" + output=$($cmd 2>&1) + echo "$output" + if echo "$output" | grep -q "valid:y"; then + valid=$((valid + 1)) + fi + total=$((total + 1)) +} + +fquant_list=( + "" + "-fquant=1 -prec_o=int8" + "-fquant=2 -prec_o=int8" + "-fquant=1 -prec_o=fp8" + "-fquant=2 -prec_o=fp8" + "-fquant=1 -prec_o=int8 -save_unquant=1" + "-fquant=2 -prec_o=int8 -save_unquant=1" + "-fquant=1 -prec_o=fp8 -save_unquant=1" + "-fquant=2 -prec_o=fp8 -save_unquant=1" +) + +m_n_list=( + "99 13" "17 16" "1 100" "4 128" "80 127" + "7 599" "19 512" "11 510" "91 636" + "31 1024" "8 1501" "3 1826" "5 2040" + "7 2734" "1 3182" "9 4096" "3 8192" +) + +### Add special stride test ### +m_n_stride_list=( + "22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256" + "33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000" + "171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818" + "12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800" + "100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812" + "64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004" +) + +for fquant in "${fquant_list[@]}"; do + for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + for pair in "${m_n_list[@]}"; do + m=$(echo $pair | cut -d ' ' -f1) + n=$(echo $pair | cut -d ' ' -f2) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "" + done + + ### Running tests with stride ### + for triple in "${m_n_stride_list[@]}"; do + m=$(echo $triple | cut -d ' ' -f1) + n=$(echo $triple | cut -d ' ' -f2) + stride_args=$(echo $triple | cut -d ' ' -f3-) + run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args" + done + done + done + done done -# The following cases uses two pass pipeline which doesn't support quant epilogue. -for fquant in "" -for pr_i in "fp16" "bf16" ; do -for fadd in "0" "1"; do -# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm -for s in "0" "1"; do -$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547 -#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134 -done -done -done +# Special two-pass only +for pr_i in "fp16" "bf16"; do + for fadd in "0" "1"; do + for s in "0" "1"; do + run_case "$pr_i" "$fadd" "$s" "" "1" "10547" "" + done + done done + +# Summary +echo "==============================" +echo "Total cases: $total" +echo "Valid cases: $valid" +accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}") +echo "Accuracy: $accuracy%" +echo "==============================" diff --git a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp index 070168b51d..424fff4470 100644 --- a/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp +++ b/include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp @@ -5,6 +5,7 @@ #include "ck_tile/core.hpp" #include "ck_tile/host/host_tensor.hpp" +#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp" namespace ck_tile { @@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, HostTensor& invRms_m, HostTensor& unquant_y_m_n, ComputeDataType epsilon, - Epilogue epilogue_functor = {}) + Epilogue epilogue_functor = {}, + const int use_model_sensitive_rmsnorm = + static_cast(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) { auto rmsnorm2d_fwd_func = [&](auto m) { const int N = x_m_n.mDesc.get_lengths()[1]; @@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, { ComputeDataType x = ck_tile::type_convert(x_m_n(m, n)); ComputeDataType gamma = ck_tile::type_convert(gamma_n(n)); - acc(m, n) = x * divisor * gamma; + if(use_model_sensitive_rmsnorm == + static_cast( + Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model + { + acc(m, n) = x * divisor * gamma; + } + else if(use_model_sensitive_rmsnorm == + static_cast(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model + { + if constexpr(std::is_same_v) + { + const auto tmp0 = float_to_bf16(x * divisor); + const auto tmp1 = float_to_bf16( + type_convert(tmp0) * gamma); + const auto rmsn_ = type_convert(tmp1); + acc(m, n) = rmsn_; + } + else + { + const auto tmp = type_convert(x * divisor); + const auto rmsn_ = type_convert(tmp) * gamma; + acc(m, n) = rmsn_; + } + } } if constexpr(!std::is_same_v) @@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor& x_m_n, make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])( std::thread::hardware_concurrency()); } + } // namespace ck_tile diff --git a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp index b72657b785..b97a66a3ec 100644 --- a/include/ck_tile/ops/reduce/block/block_reduce2d.hpp +++ b/include/ck_tile/ops/reduce/block/block_reduce2d.hpp @@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync block_sync_lds(); // We let each warp holds a duplication to do reduction. + const index_t local_warp_id = warp_id / num_reduce_warps; + const index_t local_smem_os = local_warp_id * num_reduce_warps; static_for<0, thread_buf_size, 1>{}([&](auto i) { DataType v = 0; if(lane_id < num_reduce_warps) { - v = smem_ptr[lane_id + i * num_warps]; + v = smem_ptr[i * num_warps + local_smem_os + lane_id]; } // cross-lane reduce for replication diff --git a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp index c5923ba10d..1d5467b459 100644 --- a/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp +++ b/include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_model_sensitive_pass.hpp @@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass // compute mean square each-thread->cross-lane->cross-warp auto square_sum = block_reduce2d.template MakeYBlockTile(); set_tile(square_sum, 0); - if constexpr(Problem::BlockShape::Vector_N % 2 == 0) + if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0) { sweep_tile( acc, @@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass const auto gamma_ = type_convert(gamma[j_idx]); - if constexpr(std::is_same_v) + if constexpr(std::is_same_v) { const auto tmp0 = float_to_bf16(acc[idx] * inv_rms_[i_idx]); @@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass } else { - const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); + const auto tmp = type_convert(acc[idx] * inv_rms_[i_idx]); const auto rmsn_ = type_convert(tmp) * gamma_; rmsn(idx) = rmsn_; }