Fix redundant cast in model sensitive rmsnorm (#3681)

* Fix redundant cast

* Fix linting
This commit is contained in:
MHYangAMD
2026-01-30 10:52:19 +08:00
committed by GitHub
parent 83b6155354
commit 6ff0737843

View File

@@ -181,12 +181,10 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 =
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
type_convert<ComputeDataType>(tmp0) * gamma_);
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
rmsn(idx) = rmsn_;
const auto tmp = acc[idx] * inv_rms_[i_idx];
const auto tmp_bf16 = float_to_bf16<bf16_rounding_mode::standard>(tmp);
const auto rmsn_ = type_convert<ComputeDataType>(tmp_bf16) * gamma_;
rmsn(idx) = rmsn_;
}
else
{