From d846292c8506dc196db6ca4415e3da7aa77b47d3 Mon Sep 17 00:00:00 2001 From: letaoqin Date: Thu, 12 Dec 2024 14:22:32 +0000 Subject: [PATCH] rewite save o code --- .../pipeline/fused_moegemm_pipeline_general.hpp | 14 ++++++-------- .../fused_moegemm_pipeline_general_policy.hpp | 17 +++++++++-------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp index ee25fef679..00ab1c6834 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp @@ -306,25 +306,23 @@ struct FusedMoeGemmPipeline_General make_tuple(number<32>{}, number<32>{}), {0, 0}, Policy::template MakeGlobalTileDistribution_O()); - ignore = o_alds_win; auto save_o = [&]() { if(blockIdx.x == 0 && (blockIdx.y == 0 || blockIdx.y == 1) && blockIdx.z == 0) { if(threadIdx.x < 64) { - auto o0 = load_tile(o_olds_win); - for(int step = 1; step < 4; step++) - { + auto o0 = load_tile(o_olds_win); + constexpr index_t thread_buffer_size = decltype(o0)::get_thread_buffer_size(); + static_for<1, BlockShape::Repeat_K1, 1>{}([&](auto) { move_tile_window(o_olds_win, {32, 0}); auto o1 = load_tile(o_olds_win); - for(int i = 0; i < 16; i++) - { + static_for<0, thread_buffer_size, 1>{}([&](auto i) { o0.get_thread_buffer()(i) = type_convert( type_convert(o0.get_thread_buffer()[i]) + type_convert(o1.get_thread_buffer()[i])); - } - } + }); + }); update_tile(o_window_, o0); } } diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp index b220e332c9..2930b4d87c 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general_policy.hpp @@ -216,14 +216,15 @@ struct FusedMoeGemmPipelineGeneralPolicy typename S_::WarpTile_0>>; constexpr auto warp_gemm = GetWarpGemm0(); - using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy; + using BlockGemmPolicy = + BlockGemmASmemBSmemCRegV1CustomPolicy; return BlockGemmASmemBSmemCRegV1{}; // return BlockGemmASmemBRegCRegV1{};