From 9a22805e92e4a2f10312fc1a26fcf8b74b84c8ea Mon Sep 17 00:00:00 2001 From: rocking Date: Sun, 27 Oct 2024 11:42:58 +0000 Subject: [PATCH] Fix bug of kSaveX == false --- .../add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index 0d436b143f..0dbb20645a 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -63,8 +63,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); auto b_window = make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); - auto x_window = - make_tile_window(x_window_, Policy::template MakeABXBlockTileDistribution()); + auto x_window = [&]() { + if constexpr(kSaveX) + return make_tile_window(x_window_, + Policy::template MakeABXBlockTileDistribution()); + else + return x_window_; + }(); auto gamma_window = make_tile_window( gamma_window_, Policy::template MakeGammaBlockTileDistribution());