Merge commit 'e1b0bdfbfa92f47006fdbced627c7470eacdea2b' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-13 19:10:56 +00:00
parent 713691609c
commit 63d907604b
7 changed files with 217 additions and 67 deletions

View File

@@ -400,11 +400,13 @@ struct BlockReduce2dTreeCrossWarpSync
block_sync_lds();
// We let each warp holds a duplication to do reduction.
const index_t local_warp_id = warp_id / num_reduce_warps;
const index_t local_smem_os = local_warp_id * num_reduce_warps;
static_for<0, thread_buf_size, 1>{}([&](auto i) {
DataType v = 0;
if(lane_id < num_reduce_warps)
{
v = smem_ptr[lane_id + i * num_warps];
v = smem_ptr[i * num_warps + local_smem_os + lane_id];
}
// cross-lane reduce for replication

View File

@@ -146,7 +146,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
// compute mean square each-thread->cross-lane->cross-warp
auto square_sum = block_reduce2d.template MakeYBlockTile<decltype(acc)>();
set_tile(square_sum, 0);
if constexpr(Problem::BlockShape::Vector_N % 2 == 0)
if constexpr((Problem::BlockShape::Repeat_N * Problem::BlockShape::Vector_N) % 2 == 0)
{
sweep_tile(
acc,
@@ -179,7 +179,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);
if constexpr(std::is_same_v<YResidualDataType, ck_tile::bf16_t>)
if constexpr(std::is_same_v<XDataType, ck_tile::bf16_t>)
{
const auto tmp0 =
float_to_bf16<bf16_rounding_mode::standard>(acc[idx] * inv_rms_[i_idx]);
@@ -190,7 +190,7 @@ struct Rmsnorm2dFwdPipelineModelSensitiveT5Pass
}
else
{
const auto tmp = type_convert<YResidualDataType>(acc[idx] * inv_rms_[i_idx]);
const auto tmp = type_convert<XDataType>(acc[idx] * inv_rms_[i_idx]);
const auto rmsn_ = type_convert<ComputeDataType>(tmp) * gamma_;
rmsn(idx) = rmsn_;
}