From 580d93dc8a1ef376cd534967246f1da4aa75ad8f Mon Sep 17 00:00:00 2001 From: letaoqin Date: Tue, 17 Dec 2024 03:55:45 +0000 Subject: [PATCH] rewrite save o --- example/ck_tile/17_fused_moe_general/main.cpp | 2 +- .../core/algorithm/indexing_adaptor.hpp | 14 ++++++++------ .../kernel/fused_moegemm_general_kernel.hpp | 18 +++++------------- .../fused_moegemm_pipeline_general.hpp | 16 +++++++++++----- 4 files changed, 25 insertions(+), 25 deletions(-) diff --git a/example/ck_tile/17_fused_moe_general/main.cpp b/example/ck_tile/17_fused_moe_general/main.cpp index bb9391b7db..713fab50a2 100644 --- a/example/ck_tile/17_fused_moe_general/main.cpp +++ b/example/ck_tile/17_fused_moe_general/main.cpp @@ -500,8 +500,8 @@ bool run(const ck_tile::ArgParser& arg_parser) auto o_dev = o_buf.ToHost(); auto c_dev = c_buf.ToHost(); std::cout << std::endl; - // std::cout << o_dev << std::endl; // std::cout << c_dev << std::endl; + std::cout << o_dev << std::endl; // int count = 0; // std::cout << "["; // for(int i = 0; i < tokens; i++) diff --git a/include/ck_tile/core/algorithm/indexing_adaptor.hpp b/include/ck_tile/core/algorithm/indexing_adaptor.hpp index c1d993125e..2d6f2ee315 100644 --- a/include/ck_tile/core/algorithm/indexing_adaptor.hpp +++ b/include/ck_tile/core/algorithm/indexing_adaptor.hpp @@ -81,7 +81,7 @@ struct indexing_adaptor #if Using_Gather pre_up_index_ = idx_up[number<0>{}]; pre_low_index_ = idx_low(number<0>{}); -#if 0 +#if 1 if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{})); @@ -93,8 +93,8 @@ struct indexing_adaptor template CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low, const UpIdxDiff& idx_diff_up, - LowIdx& /*idx_low*/, - const UpIdx& /*idx_up*/) const + LowIdx& idx_low, + const UpIdx& idx_up) const { // TODO: nonthing changed here static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 && @@ -109,14 +109,16 @@ struct indexing_adaptor pre_up_index_ = up_index; pre_low_index_ = low_index; -#if 0 +#if 1 if(threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { - printf("\n index form %d to %d, diff from %d to %d \n", + printf("\n index form %d to %d, idx_diff_low %d, idx_diff_up: %d, idx_low: %d, idx_up: %d \n", up_index, low_index, + idx_diff_low(number<0>{}), idx_diff_up[number<0>{}], - idx_diff_low(number<0>{})); + idx_low(number<0>{}), + idx_up.at(number<0>{})); } #endif #endif diff --git a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp index b533c1e4e8..9ab946761c 100644 --- a/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp +++ b/include/ck_tile/ops/fused_moe/kernel/fused_moegemm_general_kernel.hpp @@ -252,13 +252,6 @@ struct FusedMoeGemmGlKernel index_t idx_n0 = __builtin_amdgcn_readfirstlane(intermediate_tile_id * BlockShape::Block_N0); - // const auto a_coord = Pipeline::GetACoord(); // 2d thread offset, [i_row, i_col] - // const auto sorted_token_id = a_coord[number<0>{}] + idx_m0; // start block_m - // // position - - // auto topk_weight = - // reinterpret_cast(kargs.sorted_weight_ptr)[sorted_token_id]; - const index_t* sorted_token_ids_ptr = reinterpret_cast(kargs.sorted_token_ids_ptr); @@ -375,18 +368,17 @@ struct FusedMoeGemmGlKernel }(); const auto w_window = [&]() { - const TopkWeightDataType* w_ptr = reinterpret_cast(kargs.sorted_weight_ptr); - const auto w_view_ = make_naive_tensor_view( + const TopkWeightDataType* w_ptr = + reinterpret_cast(kargs.sorted_weight_ptr); + const auto w_view_ = make_naive_tensor_view( w_ptr, make_tuple(kargs.max_num_tokens_padded), make_tuple(1), number<1>{}, number<1>{}); - const auto w_window_ = make_tile_window( - w_view_, - make_tuple(number{}), - {idx_m0}); + const auto w_window_ = + make_tile_window(w_view_, make_tuple(number{}), {idx_m0}); return w_window_; }(); diff --git a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp index eaa54e5934..33d5384dc4 100644 --- a/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp +++ b/include/ck_tile/ops/fused_moe/pipeline/fused_moegemm_pipeline_general.hpp @@ -348,22 +348,28 @@ struct FusedMoeGemmPipeline_General while(iCounter1 > 0) { clear_tile(o_acc); - block_sync_lds(); + block_sync_lds_direct_load(); gemm_1(o_acc, y, d); - block_sync_lds(); + move_tile_window(d_global_to_dram_window, {kN1, 0}); d = load_tile(d_global_to_dram_window); + // move out window and save data + tile_elementwise_inout([&weight](auto& x) { x = x * type_convert(weight); }, + o_acc); auto o = cast_tile(o_acc); - store_tile(o_window_, o); - move_tile_window(o_window_, {kN1, 0}); + store_tile(o_alds_win, o); + block_sync_lds(); + save_o(); + + move_tile_window(o_window_, {0, kN1}); iCounter1--; } // tail { clear_tile(o_acc); - block_sync_lds(); + block_sync_lds_direct_load(); gemm_1(o_acc, y, d); // block_sync_lds();