ck_moe: fix useless code and remove useless oob (#1972)

* fix useless code and remove useless oob

* clang format

---------

Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
feli
2025-03-13 00:22:42 +08:00
committed by GitHub
parent 4c97cc511e
commit 251afab3b7
2 changed files with 5 additions and 51 deletions

View File

@@ -221,16 +221,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
auto gather_offset =
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
// maintain a container record is_src_valid, waiting for RunWrite use.
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
const bool is_src_valid =
ld_offset <
src_desc
.GetElementSpaceSize(); // hack felix, todo use coord
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
// src_coord_) && (gather_offset < 32*512);
src_oob_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
.template SetAsType<bool>(src_data_idx_seq, true);
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
using src_vector_t = typename src_vector_type::type;
@@ -399,10 +392,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<vector_t>(src_data_idx_seq);
const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
.template GetAsType<bool>(src_data_idx_seq);
auto op_r_v = is_src_valid ? op_r : vector_t(0);
auto op_r_v = op_r;
src_thread_scratch_tuple_(thread_scratch_id)
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);

View File

@@ -127,10 +127,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
{
static_for<0, nDst, 1>{}([&](auto i) {
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
// printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x,
// dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1],
// dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3],
// dst_coords_(i).GetOffset());
});
}
@@ -182,9 +178,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
"scatter weight dim, should only one vec");
constexpr auto iScatter =
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value,
// scatter_weights(Number<iScatter>{}));
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
src_vectors(i).template AsType<float>()(j) =
scatter_weights(Number<iScatter>{});
@@ -196,16 +189,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
using DataType = remove_cvref_t<decltype(data_types[i])>;
const auto tmp =
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x,
// i.value, src_coords_[i].GetOffset(), tmp);
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
}
else
{
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value);
src_vectors(i).template AsType<src_vector_t>()(I0) =
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
}
@@ -442,29 +430,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
}
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
constexpr InMemoryDataOperationEnum DstInMemOp =
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
// if(threadIdx.x==0)
// printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(),
// scatter_offset );
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
// static_for<0, 1, 1>{}([&](auto idx) {
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
// using print_vec_t = typename vector_type<DstData, 1>::type;
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset,
// is_dst_valid, type_convert<float>(dst_vectors[i].template
// AsType<print_vec_t>()[idx]));
// });
// }
dst_offset, true, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
});
// move coordinate
@@ -478,10 +450,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
// ordered_gather_dim);
});
return step_;
@@ -555,10 +523,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) =
(i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
// if(threadIdx.x==0)
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
// ordered_gather_dim);
});
return step_;