Merge commit 'e1b0bdfbfa92f47006fdbced627c7470eacdea2b' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-13 19:10:56 +00:00
parent 713691609c
commit 63d907604b
7 changed files with 217 additions and 67 deletions

View File

@@ -75,6 +75,39 @@ struct layernorm2d_fwd_traits_
using SmoothScaleDataType = ck_tile::remove_cvref_t<SmoothScaleDataType_>;
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps;
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;

View File

@@ -75,6 +75,39 @@ struct rmsnorm2d_fwd_traits_
using YScaleDataType = ck_tile::remove_cvref_t<YScaleDataType_>;
using UnquantYDataType = ck_tile::remove_cvref_t<UnquantYDataType_>;
static constexpr bool is_warp_per_row = ThreadPerBlock_N_ <= ck_tile::get_warp_size();
static_assert((ThreadPerBlock_M_ * ThreadPerBlock_N_) % ck_tile::get_warp_size() == 0);
static constexpr ck_tile::index_t total_warps =
(ThreadPerBlock_M_ * ThreadPerBlock_N_) / ck_tile::get_warp_size();
// num of warps along m
static constexpr ck_tile::index_t BlockWarps_M = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return total_warps;
}
else
{
// static_assert(ck_tile::get_warp_size() % ThreadPerBlock_M_ == 0);
return total_warps / (ThreadPerBlock_N_ / ck_tile::get_warp_size());
}
}();
// num of warps along n
static constexpr ck_tile::index_t BlockWarps_N = []() {
if constexpr(is_warp_per_row)
{
static_assert(ck_tile::get_warp_size() % ThreadPerBlock_N_ == 0);
return 1;
}
else
{
static_assert(ThreadPerBlock_N_ % ck_tile::get_warp_size() == 0);
return ThreadPerBlock_N_ / ck_tile::get_warp_size();
}
}();
static constexpr ck_tile::index_t Repeat_M = Repeat_M_;
static constexpr ck_tile::index_t Repeat_N = Repeat_N_;
@@ -605,15 +638,15 @@ float rmsnorm2d_fwd(rmsnorm2d_fwd_traits t,
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1, 256, 4, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 12, 1, 256, 2, True, False, False, True, 0, 0, 1),
h_traits('x', 'y', 'xs', 'ys', 'uqy', 1, 4, 1,1024, 1, True, False, False, True, 0, 0, 1)]
}
}
}
total_blob = list()
for model_sensitive_flag in [0, 1]: # 0: default; 1: model sensitive
current_trait_dict = h_trait_dicts[model_sensitive_flag]
for hs_key in current_trait_dict:
hs = current_trait_dict[hs_key]
hs = current_trait_dict[hs_key]
current_n = hs_key
for dtype, scale_type, fused_add, fused_quant, save_unquant in itertools.product(dtype_list, scale_list, fused_add_list, fused_sweep_list, bool_list):
prec_i, prec_o = dtype.split(',')

View File

@@ -70,16 +70,16 @@ template <typename InDataType,
bool SaveUnquant>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
const int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t m = arg_parser.get_int("m");
ck_tile::index_t n = arg_parser.get_int("n");
float epsilon = arg_parser.get_float("e");
int kname = arg_parser.get_int("kname");
int do_validation = arg_parser.get_int("v");
int fused_add = arg_parser.get_int("fadd");
int fused_quant = arg_parser.get_int("fquant");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
int use_model_sensitive_rmsnorm = arg_parser.get_int("s");
ck_tile::index_t x_stride = arg_parser.get_int("x_stride");
if(x_stride < 0)
@@ -196,6 +196,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
return base_str;
}();
if(n > 8192)
{
use_model_sensitive_rmsnorm = 0;
}
std::cout << "[" << prec_str << "]" << " m:" << m << ", n:" << n << ", x_stride:" << x_stride
<< ", xr_stride:" << xr_stride << ", y_stride:" << y_stride
<< ", yr_stride:" << yr_stride << ", s:" << use_model_sensitive_rmsnorm << std::flush;
@@ -297,7 +302,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
const int N = acc_.mDesc.get_lengths()[1];
for(int n_ = 0; n_ < N; ++n_)
{
o_unquant_(m_, n_) = ck_tile::type_convert<OutDataType>(acc_(m_, n_));
o_unquant_(m_, n_) = ck_tile::type_convert<UnquantYDataType>(acc_(m_, n_));
}
dquant_functor(m_, o_, acc_);
@@ -316,7 +321,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
default_and_dquant_functor);
default_and_dquant_functor,
use_model_sensitive_rmsnorm);
}
else
{
@@ -331,7 +337,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
invRms_host_ref,
unquant_y_host_ref,
epsilon,
dquant_functor);
dquant_functor,
use_model_sensitive_rmsnorm);
}
}
else
@@ -343,7 +350,14 @@ bool run(const ck_tile::ArgParser& arg_parser)
YDataType,
InvRmsDataType,
ck_tile::null_type>(
x_host, gamma_host, y_host_ref, invRms_host_ref, unquant_y_null, epsilon);
x_host,
gamma_host,
y_host_ref,
invRms_host_ref,
unquant_y_null,
epsilon,
ck_tile::reference_rmsnorm2d_default_epilogue{},
use_model_sensitive_rmsnorm);
}
y_buf.FromDevice(y_host_dev.data());
@@ -354,6 +368,11 @@ bool run(const ck_tile::ArgParser& arg_parser)
y_residual_buf.FromDevice(y_residual_host_dev.data());
}
if constexpr(SaveUnquant)
{
unquant_y_buf.FromDevice(unquant_y_host_dev.data());
}
auto [rtol, atol] = get_elimit<YDataType>();
if(x_stride == n)
{

View File

@@ -1,49 +1,85 @@
#!/bin/sh
#!/bin/bash
EXE="$(find . -name tile_rmsnorm2d_fwd -type f | head -n 1)"
for fquant in "" "-fquant=1 -prec_o=int8" "-fquant=2 -prec_o=int8" "-fquant=1 -prec_o=fp8" "-fquant=2 -prec_o=fp8"\
"-fquant=1 -prec_o=int8 -save_unquant=1" "-fquant=2 -prec_o=int8 -save_unquant=1" "-fquant=1 -prec_o=fp8 -save_unquant=1" "-fquant=2 -prec_o=fp8 -save_unquant=1"; do
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=99 -n=13
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=17 -n=16
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=100
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=4 -n=128
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=80 -n=127
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=22 -n=255 -stride=256
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=599
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=19 -n=512
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=33 -n=313 -stride=1000
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=11 -n=510
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=171 -n=676 -stride=818
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=91 -n=636
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=12 -n=768 -stride=800
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=100 -n=766 -stride=812
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=31 -n=1024
# $EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=64 -n=1000 -stride=1004
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=8 -n=1501
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=1826
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=5 -n=2040
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=7 -n=2734
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=3182
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=9 -n=4096
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=3 -n=8192
done
done
done
total=0
valid=0
run_case() {
cmd="$EXE -prec_i=$1 -fadd=$2 -s=$3 $4 -m=$5 -n=$6 $7"
echo "[CMD] $cmd"
output=$($cmd 2>&1)
echo "$output"
if echo "$output" | grep -q "valid:y"; then
valid=$((valid + 1))
fi
total=$((total + 1))
}
fquant_list=(
""
"-fquant=1 -prec_o=int8"
"-fquant=2 -prec_o=int8"
"-fquant=1 -prec_o=fp8"
"-fquant=2 -prec_o=fp8"
"-fquant=1 -prec_o=int8 -save_unquant=1"
"-fquant=2 -prec_o=int8 -save_unquant=1"
"-fquant=1 -prec_o=fp8 -save_unquant=1"
"-fquant=2 -prec_o=fp8 -save_unquant=1"
)
m_n_list=(
"99 13" "17 16" "1 100" "4 128" "80 127"
"7 599" "19 512" "11 510" "91 636"
"31 1024" "8 1501" "3 1826" "5 2040"
"7 2734" "1 3182" "9 4096" "3 8192"
)
### Add special stride test ###
m_n_stride_list=(
"22 255 -x_stride=256 -xr_stride=256 -y_stride=256 -yr_stride=256"
"33 313 -x_stride=1000 -xr_stride=1000 -y_stride=1000 -yr_stride=1000"
"171 676 -x_stride=818 -xr_stride=818 -y_stride=818 -yr_stride=818"
"12 768 -x_stride=800 -xr_stride=800 -y_stride=800 -yr_stride=800"
"100 766 -x_stride=812 -xr_stride=812 -y_stride=812 -yr_stride=812"
"64 1000 -x_stride=1004 -xr_stride=1004 -y_stride=1004 -yr_stride=1004"
)
for fquant in "${fquant_list[@]}"; do
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
for pair in "${m_n_list[@]}"; do
m=$(echo $pair | cut -d ' ' -f1)
n=$(echo $pair | cut -d ' ' -f2)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" ""
done
### Running tests with stride ###
for triple in "${m_n_stride_list[@]}"; do
m=$(echo $triple | cut -d ' ' -f1)
n=$(echo $triple | cut -d ' ' -f2)
stride_args=$(echo $triple | cut -d ' ' -f3-)
run_case "$pr_i" "$fadd" "$s" "$fquant" "$m" "$n" "$stride_args"
done
done
done
done
done
# The following cases uses two pass pipeline which doesn't support quant epilogue.
for fquant in ""
for pr_i in "fp16" "bf16" ; do
for fadd in "0" "1"; do
# 0: for no specific RMSNorm; 1: for T-5 like RMSNorm
for s in "0" "1"; do
$EXE -prec_i=$pr_i -fadd=$fadd -s=$s $fquant -m=1 -n=10547
#$EXE -prec_i=$pr_i -fadd=$fadd $fquant -m=3 -n=17134
done
done
done
# Special two-pass only
for pr_i in "fp16" "bf16"; do
for fadd in "0" "1"; do
for s in "0" "1"; do
run_case "$pr_i" "$fadd" "$s" "" "1" "10547" ""
done
done
done
# Summary
echo "=============================="
echo "Total cases: $total"
echo "Valid cases: $valid"
accuracy=$(awk "BEGIN {printf \"%.2f\", ($valid / $total) * 100}")
echo "Accuracy: $accuracy%"
echo "=============================="

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace ck_tile {
@@ -43,7 +44,9 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
HostTensor<InvRmsDataType>& invRms_m,
HostTensor<UnquantYDataType>& unquant_y_m_n,
ComputeDataType epsilon,
Epilogue epilogue_functor = {})
Epilogue epilogue_functor = {},
const int use_model_sensitive_rmsnorm =
static_cast<int>(Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL))
{
auto rmsnorm2d_fwd_func = [&](auto m) {
const int N = x_m_n.mDesc.get_lengths()[1];
@@ -68,7 +71,30 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
{
ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
acc(m, n) = x * divisor * gamma;
if(use_model_sensitive_rmsnorm ==
static_cast<int>(
Rmsnorm2dSensitiveEnum::NO_SPECIFIC_MODEL)) // 0: for no specific model
{
acc(m, n) = x * divisor * gamma;
}
else if(use_model_sensitive_rmsnorm ==
static_cast<int>(Rmsnorm2dSensitiveEnum::T5_MODEL_LIKE)) // 1: for T5-like model
{
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 = float_to_bf16<bf16_rounding_mode::standard>(x * divisor);
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
type_convert<ComputeDataType>(tmp0) * gamma);
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
acc(m, n) = rmsn_;
}
else
{
const auto tmp = type_convert<XDataType>(x * divisor);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma;
acc(m, n) = rmsn_;
}
}
}
if constexpr(!std::is_same_v<UnquantYDataType, ck_tile::null_type>)
@@ -84,4 +110,5 @@ void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
make_ParallelTensorFunctor(rmsnorm2d_fwd_func, invRms_m.mDesc.get_lengths()[0])(
std::thread::hardware_concurrency());
}
} // namespace ck_tile

View File

@@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync
block_sync_lds();
// We let each warp holds a duplication to do reduction.
const index_t local_warp_id = warp_id / num_reduce_warps;
const index_t local_smem_os = local_warp_id * num_reduce_warps;
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v = 0;
if(lane_id < num_reduce_warps)
{
v = smem_ptr[lane_id + i * num_warps];
v = smem_ptr[i * num_warps + local_smem_os + lane_id];
}
// cross-lane reduce for replication

View File

@@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
set_tile(square_sum, 0);
if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0)
{
sweep_tile(
acc,
@@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 =
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
@@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
}
else
{
const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
const auto tmp = type_convert<XDataType>(acc[idx] * inv_rms_[i_idx]);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
rmsn(idx) = rmsn_;
}