diff --git a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp index 0d436b143f..0dbb20645a 100644 --- a/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp +++ b/include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp @@ -63,8 +63,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution()); auto b_window = make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution()); - auto x_window = - make_tile_window(x_window_, Policy::template MakeABXBlockTileDistribution()); + auto x_window = [&]() { + if constexpr(kSaveX) + return make_tile_window(x_window_, + Policy::template MakeABXBlockTileDistribution()); + else + return x_window_; + }(); auto gamma_window = make_tile_window( gamma_window_, Policy::template MakeGammaBlockTileDistribution());