mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm
* Update rmsnorm host reference
* Update tree reduction of rmsnorm for reference host
* Fix cross warp for m > 1 cases
* Add RMSNorm model selectable option for host reference
* Fix save_unquant cases
* Update reference rmsnorm forward function to use enum for model sensitivity
* Update reference rmsnorm calculation for model sensitivity
* Fix m warp for layernorm
* Adjust parameter of reference for twoPass
* Fix clang format
* Run clang-format-overwrite.sh to fix formatting issue
* fix clang format

---------

Co-authored-by: MHYang <mengyang@amd.com>
Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
// compute mean square each-thread->cross-lane->cross-warp
|
||||
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
|
||||
set_tile(square_sum, 0);
|
||||
- if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
|
||||
+ if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0)
|
||||
{
|
||||
sweep_tile(
|
||||
acc,
|
||||
@@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
|
||||
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
|
||||
|
||||
- if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
|
||||
+ if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
const auto tmp0 =
|
||||
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
|
||||
@@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
}
|
||||
else
|
||||
{
|
||||
- const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
|
||||
+ const auto tmp = type_convert<XDataType>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
|
||||
rmsn(idx) = rmsn_;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user