mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm (#2540)
* [CK_TILE] Correct BlockWarps calculation and fix smoke-test in rmsnorm * Update rmsnorm host reference * Update tree reduction of rmsnorm for reference host * Fix cross warp for m > 1 cases * Add RMSNorm model selectable option for host reference * Fix save_unquant cases * Update reference rmsnorm forward function to use enum for model sensitivity * Update reference rmsnorm calculation for model sensitivity * Fix m warp for layernorm * Adjust parameter of reference for twoPass * Fix clang format * Run clang-format-overwrite.sh to fix formating issue * fix clang format --------- Co-authored-by: MHYang <mengyang@amd.com> Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com>
This commit is contained in:
@@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync
|
||||
block_sync_lds();
|
||||
|
||||
// We let each warp holds a duplication to do reduction.
|
||||
const index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
const index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v = 0;
|
||||
if(lane_id < num_reduce_warps)
|
||||
{
|
||||
v = smem_ptr[lane_id + i * num_warps];
|
||||
v = smem_ptr[i * num_warps + local_smem_os + lane_id];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
|
||||
Reference in New Issue
Block a user