mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-08 15:30:23 +00:00
Fix bug in rmsnorm
This commit is contained in:
@@ -134,10 +134,8 @@ struct Rmsnorm2dFwd
|
||||
number<Vector_N>{},
|
||||
number<1>{});
|
||||
|
||||
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel will
|
||||
// check the max count dynamically
|
||||
const auto tmp2_ = pad_tensor_view(
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<false, false>{});
|
||||
tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
|
||||
return make_tile_window(
|
||||
tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
|
||||
}();
|
||||
@@ -151,7 +149,7 @@ struct Rmsnorm2dFwd
|
||||
number<1>{});
|
||||
|
||||
const auto tmp2_ =
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
|
||||
pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadM>{});
|
||||
|
||||
return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
|
||||
}();
|
||||
|
||||
Reference in New Issue
Block a user