mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
ck_moe: fix useless code and remove usless oob (#1972)
* fix useless code and remove usless oob * clang format --------- Co-authored-by: coderfeli <coderfeli@163.com>
This commit is contained in:
@@ -221,16 +221,9 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
auto gather_offset =
|
||||
gather_offsets_(ordered_src_access_idx[Number<ordered_gather_dim>{}]);
|
||||
|
||||
// maintain a container record is_src_valid, waiting for RunWrite use.
|
||||
const index_t ld_offset = src_coord_.GetOffset() + gather_offset;
|
||||
const bool is_src_valid =
|
||||
ld_offset <
|
||||
src_desc
|
||||
.GetElementSpaceSize(); // hack felix, todo use coord
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc,
|
||||
// src_coord_) && (gather_offset < 32*512);
|
||||
src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<bool>(src_data_idx_seq, is_src_valid);
|
||||
.template SetAsType<bool>(src_data_idx_seq, true);
|
||||
|
||||
using src_vector_type = vector_type_maker_t<SrcData, SrcScalarPerVector>;
|
||||
using src_vector_t = typename src_vector_type::type;
|
||||
@@ -399,10 +392,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1_gather
|
||||
auto op_r = src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<vector_t>(src_data_idx_seq);
|
||||
|
||||
const bool is_src_valid = src_oob_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template GetAsType<bool>(src_data_idx_seq);
|
||||
|
||||
auto op_r_v = is_src_valid ? op_r : vector_t(0);
|
||||
auto op_r_v = op_r;
|
||||
|
||||
src_thread_scratch_tuple_(thread_scratch_id)
|
||||
.template SetAsType<vector_t>(src_data_idx_seq, op_r_v);
|
||||
|
||||
@@ -127,10 +127,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
{
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
dst_coords_(i) = make_tensor_coordinate(dst_descs[i], dst_slice_origin_idxs[i]);
|
||||
// printf("tid %d origin %d %d %d %d off %d\n", threadIdx.x,
|
||||
// dst_slice_origin_idxs[i][I0], dst_slice_origin_idxs[i][I1],
|
||||
// dst_slice_origin_idxs[i][I2], dst_slice_origin_idxs[i][I3],
|
||||
// dst_coords_(i).GetOffset());
|
||||
});
|
||||
}
|
||||
|
||||
@@ -182,9 +178,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
"scatter weight dim, should only one vec");
|
||||
constexpr auto iScatter =
|
||||
SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value,
|
||||
// scatter_weights(Number<iScatter>{}));
|
||||
static_for<0, SrcScalarPerVector, 1>{}([&](auto j) {
|
||||
src_vectors(i).template AsType<float>()(j) =
|
||||
scatter_weights(Number<iScatter>{});
|
||||
@@ -196,16 +189,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
using DataType = remove_cvref_t<decltype(data_types[i])>;
|
||||
const auto tmp =
|
||||
src_bufs[i].template Get<DataType>(src_coords_[i].GetOffset(), true);
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d off %d v %f\n", blockIdx.y, threadIdx.x,
|
||||
// i.value, src_coords_[i].GetOffset(), tmp);
|
||||
static_for<0, SrcScalarPerVector, 1>{}(
|
||||
[&](auto j) { src_vectors(i).template AsType<DataType>()(j) = tmp; });
|
||||
}
|
||||
else
|
||||
{
|
||||
// if(threadIdx.x % 8 ==0 )
|
||||
// printf("bid %d tid %d srcid %d vn\n", blockIdx.y, threadIdx.x, i.value);
|
||||
src_vectors(i).template AsType<src_vector_t>()(I0) =
|
||||
src_bufs[i].template Get<src_vector_t>(src_coords_[i].GetOffset(), true);
|
||||
}
|
||||
@@ -442,29 +430,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
}
|
||||
// copy data from buf_vectors into dst_bufs
|
||||
static_for<0, nDst, 1>{}([&](auto i) {
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
|
||||
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
|
||||
// dst_coords_[i]);
|
||||
|
||||
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
|
||||
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
|
||||
constexpr InMemoryDataOperationEnum DstInMemOp =
|
||||
static_cast<InMemoryDataOperationEnum>(DstInMemOps::At(i.value));
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("use tid %d off %d %d\n", threadIdx.x, dst_coords_[i].GetOffset(),
|
||||
// scatter_offset );
|
||||
dst_bufs(i).template Update<DstInMemOp, dst_vector_t>(
|
||||
dst_offset, is_dst_valid, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
|
||||
// static_for<0, 1, 1>{}([&](auto idx) {
|
||||
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
|
||||
// using print_vec_t = typename vector_type<DstData, 1>::type;
|
||||
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset,
|
||||
// is_dst_valid, type_convert<float>(dst_vectors[i].template
|
||||
// AsType<print_vec_t>()[idx]));
|
||||
// });
|
||||
// }
|
||||
dst_offset, true, dst_vectors[i].template AsType<dst_vector_t>()[I0]);
|
||||
});
|
||||
|
||||
// move coordinate
|
||||
@@ -478,10 +450,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : forward_step[i];
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
|
||||
// ordered_gather_dim);
|
||||
});
|
||||
|
||||
return step_;
|
||||
@@ -555,10 +523,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
|
||||
static_for<0, nDim, 1>{}([&](auto i) {
|
||||
step_(i) =
|
||||
(i.value == ScatterDim && OutputScatter) ? 0 : reset_step[Number<i>{}];
|
||||
|
||||
// if(threadIdx.x==0)
|
||||
// printf("i %d %d ordered_gather_dim %d\n", i.value, step_(i),
|
||||
// ordered_gather_dim);
|
||||
});
|
||||
|
||||
return step_;
|
||||
|
||||
Reference in New Issue
Block a user