mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
Fix redundant cast in model sensitive rmsnorm (#3681)
* Fix redundant cast * Fix linting
This commit is contained in:
@@ -181,12 +181,10 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
|
||||
|
||||
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
|
||||
{
|
||||
const auto tmp0 =
|
||||
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
|
||||
const auto tmp1 = float_to_bf16<bf16_rounding_mode::standard>(
|
||||
type_convert<ComputeDataType>(tmp0) * gamma_);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp1);
|
||||
rmsn(idx) = rmsn_;
|
||||
const auto tmp = acc[idx] * inv_rms_[i_idx];
|
||||
const auto tmp_bf16 = float_to_bf16<bf16_rounding_mode::standard>(tmp);
|
||||
const auto rmsn_ = type_convert<ComputeDataType>(tmp_bf16) * gamma_;
|
||||
rmsn(idx) = rmsn_;
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user