From e3c5b2ae80f68939cc7b9e91f4fdbb81a1163ec3 Mon Sep 17 00:00:00 2001 From: feli Date: Thu, 13 Mar 2025 00:22:42 +0800 Subject: [PATCH] ck_moe: fix useless code and remove usless oob (#1972) * fix useless code and remove usless oob * clang format --------- Co-authored-by: coderfeli [ROCm/composable_kernel commit: 251afab3b79190c0640ab36103054835a8cde6df] --- ...wise_tensor_slice_transfer_v3r1_gather.hpp | 14 +------ ...ise_tensor_slice_transfer_v7r3_scatter.hpp | 42 ++----------------- 2 files changed, 5 insertions(+), 51 deletions(-) diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp index 76fc18bc14..bb9a452761 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1_gather.hpp @@ -221,16 +221,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather auto gather_offset = gather_offsets_(ordered_src_access_idx[Number{}]); - // maintain a container record is_src_valid, waiting for RunWrite use. const index_t ld_offset = src_coord_.GetOffset() + gather_offset; - const bool is_src_valid = - ld_offset < - src_desc - .GetElementSpaceSize(); // hack felix, todo use coord - // coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, - // src_coord_) && (gather_offset < 32*512); src_oob_thread_scratch_tuple_(thread_scratch_id) - .template SetAsType(src_data_idx_seq, is_src_valid); + .template SetAsType(src_data_idx_seq, true); using src_vector_type = vector_type_maker_t; using src_vector_t = typename src_vector_type::type; @@ -399,10 +392,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather auto op_r = src_thread_scratch_tuple_(thread_scratch_id) .template GetAsType(src_data_idx_seq); - const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id) - .template GetAsType(src_data_idx_seq); - - auto op_r_v = is_src_valid ? op_r : vector_t(0); + auto op_r_v = op_r; src_thread_scratch_tuple_(thread_scratch_id) .template SetAsType(src_data_idx_seq, op_r_v); diff --git a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp index ea61f0bc7c..29570c94e3 100644 --- a/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp +++ b/include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v7r3_scatter.hpp @@ -127,10 +127,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter { static_for<0, nDst, 1>{}([&](auto i) { dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]); - // printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x, - // dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1], - // dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3], - // dst_coords_(i).GetOffset()); }); } @@ -182,9 +178,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter "scatter weight dim, should only one vec"); constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number{}); - // if(threadIdx.x % 8 ==0 ) - // printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, - // scatter_weights(Number{})); static_for<0, SrcScalarPerVector, 1>{}([&](auto j) { src_vectors(i).template AsType()(j) = scatter_weights(Number{}); @@ -196,16 +189,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter using DataType = remove_cvref_t; const auto tmp = src_bufs[i].template Get(src_coords_[i].GetOffset(), true); - // if(threadIdx.x % 8 ==0 ) - // printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x, - // i.value, src_coords_[i].GetOffset(), tmp); static_for<0, SrcScalarPerVector, 1>{}( [&](auto j) { src_vectors(i).template AsType()(j) = tmp; }); } else { - // if(threadIdx.x % 8 ==0 ) - // printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value); src_vectors(i).template AsType()(I0) = src_bufs[i].template Get(src_coords_[i].GetOffset(), true); } @@ -442,29 +430,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter } // copy data from buf_vectors into dst_bufs static_for<0, nDst, 1>{}([&](auto i) { - using dst_vector_t = typename remove_cvref_t::type; - auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); - const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize(); - // coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i], - // dst_coords_[i]); - + using dst_vector_t = typename remove_cvref_t::type; + auto dst_offset = scatter_offset + dst_coords_[i].GetOffset(); constexpr InMemoryDataOperationEnum DstInMemOp = static_cast(DstInMemOps::At(i.value)); - // if(threadIdx.x==0) - // printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(), - // scatter_offset ); dst_bufs(i).template Update( - dst_offset, is_dst_valid, dst_vectors[i].template AsType()[I0]); - // if(threadIdx.x%8 ==0 && blockIdx.x==0) { - // static_for<0, 1, 1>{}([&](auto idx) { - // using DstData = remove_cvref_t>; - // using print_vec_t = typename vector_type::type; - // printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, - // is_dst_valid, type_convert(dst_vectors[i].template - // AsType()[idx])); - // }); - // } + dst_offset, true, dst_vectors[i].template AsType()[I0]); }); // move coordinate @@ -478,10 +450,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static_for<0, nDim, 1>{}([&](auto i) { step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i]; - - // if(threadIdx.x==0) - // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), - // ordered_gather_dim); }); return step_; @@ -555,10 +523,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter static_for<0, nDim, 1>{}([&](auto i) { step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number{}]; - - // if(threadIdx.x==0) - // printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i), - // ordered_gather_dim); }); return step_;