i4 support lds multiple shuffle

This commit is contained in:
mtgu0705
2025-03-04 11:06:15 +08:00
parent e3a2aa4f9a
commit 27fb28ed31
3 changed files with 132 additions and 100 deletions

View File

@@ -65,16 +65,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
const StaticallyIndexedArray<Index, nSrc>& src_block_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_block_slice_origins,
const ElementwiseOperation& element_op,
const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
const ElementwiseOperation& element_op)
: threadwise_transfer_(src_descs,
StaticallyIndexedArray<Index, nSrc>{},
dst_descs,
StaticallyIndexedArray<Index, nDst>{},
element_op,
scatter_offsets,
scatter_weights)
element_op)
{
static_assert(nSrc == SrcDatas::Size() && nSrc == SrcDescs::Size() &&
nSrc == ThreadTransferSrcResetCoordinateAfterRunFlags::Size() &&
@@ -129,12 +125,13 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename SrcBuffers, index_t ThreadScratchId = 0>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
threadwise_transfer_.RunRead(src_descs, src_bufs, thread_scratch_id);
threadwise_transfer_.RunRead(src_descs, src_bufs, scatter_weights, thread_scratch_id);
}
}
@@ -144,15 +141,16 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
template <typename DstBuffers, index_t ThreadScratchId = 0>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
if(ThreadGroup::GetNumOfThread() == thread_cluster_desc_.GetElementSize() or
ThreadGroup::GetThreadId() < thread_cluster_desc_.GetElementSize())
{
if constexpr(is_detected<is_tuple, decltype(dst_bufs)>::value)
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, thread_scratch_id);
threadwise_transfer_.RunWrite(dst_descs, dst_bufs, scatter_offsets, thread_scratch_id);
else
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), thread_scratch_id);
threadwise_transfer_.RunWrite(dst_descs, tie(dst_bufs), scatter_offsets, thread_scratch_id);
}
}
@@ -160,10 +158,12 @@ struct ThreadGroupTensorSliceTransfer_v7r3_scatter
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
template <index_t ISrc>

View File

@@ -1497,32 +1497,6 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = MPerBlock / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
const float *p_sorted_weights_0 = p_ds_grid[I0];
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
@@ -1558,9 +1532,7 @@ struct GridwiseMoeGemm
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op,
scatter_offsets,
scatter_weights};
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -1589,8 +1561,37 @@ struct GridwiseMoeGemm
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
// if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
// printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
});
block_sync_lds();
// each thread write its data from VGPR to LDS
@@ -1608,7 +1609,10 @@ struct GridwiseMoeGemm
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf));
tie(c_grid_buf),
scatter_offsets,
scatter_weights
);
if constexpr(access_id < num_access - 1)
{
@@ -1664,16 +1668,32 @@ struct GridwiseMoeGemm
const index_t max_token_id = __builtin_amdgcn_readfirstlane(p_max_token_id[0]);
// constexpr int expert_tile_cnt[8] = {2, 1, 1, 2, 2, 2, 1, 2};
// const index_t b_block_id = blockIdx.x % problem.NBlock;
const index_t expert_block_id = NSwizzle ? blockIdx.x / problem.NBlock : blockIdx.y;
if (expert_block_id * MPerBlock >= max_token_id)
return;
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[expert_block_id]);
const auto block_mn = [&]() -> std::pair<int, int> {
if constexpr (NSwizzle)
{
const index_t expert_block_id = blockIdx.x / problem.NBlock;
const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
// const index_t expert_block_id = blockIdx.x / problem.NBlock; //
// const index_t es = __builtin_amdgcn_readfirstlane(p_max_token_id[expert_block_id + 1]);
// const index_t expert_swizzle = es > 0 ? es : 1; //p_max_token_id[expert_id + 1];
// const index_t expert_block_swizzle = expert_block_id / expert_swizzle;
// const index_t b_block_id_swizzle = blockIdx.x % (problem.NBlock * expert_swizzle);
// const index_t nid = __builtin_amdgcn_readfirstlane(b_block_id_swizzle % 8 + b_block_id_swizzle / (8 * expert_swizzle) * 8);
// const index_t mid = __builtin_amdgcn_readfirstlane(expert_block_swizzle * expert_swizzle + b_block_id_swizzle / 8 % expert_swizzle);
// if(threadIdx.x==0)
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, es, p_sorted_expert_ids[expert_block_id]);
const index_t ecnt_prefix = p_max_token_id[1+expert_id];
const index_t prefix_block = ecnt_prefix * problem.NBlock;
const index_t ecnt = p_max_token_id[2+expert_id] - ecnt_prefix;
const index_t expert_swizzle = ecnt > 0 ? ecnt : 1; //p_max_token_id[expert_id + 1]; // 2
const index_t bid_new = blockIdx.x - prefix_block;
const index_t nid = __builtin_amdgcn_readfirstlane(bid_new % 8 + bid_new / (8 * expert_swizzle) * 8);
const index_t mid = __builtin_amdgcn_readfirstlane(ecnt_prefix + bid_new / 8 % expert_swizzle);
// if(threadIdx.x==0)
// printf("block, %d, mid, %d, nid, %d, ecnt, %d, expert %d \n", blockIdx.x, mid, nid, ecnt, expert_id);
return {nid, mid};
} else {
return {blockIdx.x, blockIdx.y};
@@ -1681,7 +1701,7 @@ struct GridwiseMoeGemm
}();
const index_t block_n_id = block_mn.first;
const index_t block_m_id = block_mn.second;
const index_t expert_id = __builtin_amdgcn_readfirstlane(p_sorted_expert_ids[block_m_id]);
// if (threadIdx.x==0) {
// printf("bid %d, eid %d, es %d, esi %d, bsi %d, m %d, n %d\n", blockIdx.x, expert_id, expert_swizzle, expert_block_swizzle, b_block_id_swizzle, block_m_id, block_n_id);
// }
@@ -1695,7 +1715,7 @@ struct GridwiseMoeGemm
constexpr auto AMRepeats = MPerBlock / AMThreads;
const index_t token_pos = block_m_id * MPerBlock + threadIdx.x / AKThreads * AMRepeats;
if(token_pos >= max_token_id || token0 >= problem.NumTokens)
if(token_pos >= max_token_id || expert_block_id * MPerBlock >= max_token_id || token0 >= problem.NumTokens)
return;
StaticallyIndexedArray<index_t, AMRepeats> gather_offsets; //= p_sorted_token_ids[token_pos];
static_for<0, AMRepeats, 1>{}([&](auto m0) {
@@ -1989,32 +2009,6 @@ struct GridwiseMoeGemm
using CDEBlockTransferCluster =
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock;
const auto EGlobalMemoryDataOperation = CGlobalMemoryDataOperation;
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = MPerBlock / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats;
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
const float *p_sorted_weights_0 = p_ds_grid[I0];
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
// if(threadIdx.x % 16 == 0)
// printf("init off bid %d tid %d m %d off %d\n", blockIdx.y, threadIdx.x, m0(), scatter_offsets(m0));
});
constexpr index_t scatter_weight_idx = IsInputGemm ? 1 : 3; //hack fix felix
auto cde_block_copy_lds_and_global = ThreadGroupTensorSliceTransfer_v7r3_scatter<
ThisThreadBlock,
@@ -2050,9 +2044,7 @@ struct GridwiseMoeGemm
idx_c_ds_block_begin,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
make_tuple(make_multi_index(0, 0, block_n_id, 0)),
c_element_op,
scatter_offsets,
scatter_weights};
c_element_op};
auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
p_c_grid, c_grid_desc_mblock_mperblock_nblock_nperblock.GetElementSpaceSize());
@@ -2081,8 +2073,37 @@ struct GridwiseMoeGemm
CShuffleNXdlPerWavePerShuffle * NWave * NPerXdl>>{};
static_assert(num_access == sfc_cde_block.GetNumOfAccess(), "wrong!");
constexpr auto EMThreads = CDEBlockTransferCluster{}.At(I0) * CDEBlockTransferCluster{}.At(I1);
constexpr auto EMRepeats = CShuffleMXdlPerWavePerShuffle * MWave * MPerXdl / EMThreads;
constexpr auto ENThreads = CDEBlockTransferCluster{}.At(I2) * CDEBlockTransferCluster{}.At(I3);
const float *p_sorted_weights_0 = p_ds_grid[I0];
static_for<0, num_access, 1>{}([&](auto access_id) {
// make sure it's safe to write to LDS
StaticallyIndexedArray<index_t, EMRepeats> scatter_offsets; //= p_sorted_token_ids[c_token_pos];
StaticallyIndexedArray<float, EMRepeats> scatter_weights; //= for topk
// too hack here, 2 specific for topk weights, fixme
// const index_t topk_id[EMRepeats];// = (p_sorted_token_ids[block_m_id * MPerBlock] & 0xff000000) >> 24;
auto dstidx = sfc_cde_block.GetIndex(access_id);
const index_t c_token_pos = block_m_id * MPerBlock + threadIdx.x / ENThreads * EMRepeats + dstidx(I1);
static_for<0, EMRepeats, 1>{}([&](auto m0) {
const index_t fused_token = p_sorted_token_ids[c_token_pos + m0];
index_t token_offset = fused_token & 0xffffff;
float weight = p_sorted_weights_0[(c_token_pos + m0) * problem.StrideDs[0]];
if constexpr (IsInputGemm)
{
token_offset = token_offset * problem.TopK + (fused_token >> 24);
} else {
const float *p_sorted_weights_2 = p_ds_grid[I2];
weight = weight * p_sorted_weights_2[c_token_pos + m0];
}
// if(threadIdx.x % 8 == 0 && blockIdx.x == 0)
// printf("init off tid %d access %d tpos %d m %d off %d wei %f\n", threadIdx.x, dstidx(I1), c_token_pos, m0(), token_offset, weight);
scatter_offsets(m0) = token_offset * problem.N;
scatter_weights(m0) = weight;
});
block_sync_lds();
// each thread write its data from VGPR to LDS
@@ -2100,7 +2121,10 @@ struct GridwiseMoeGemm
c_ds_desc_refs,
c_ds_buf_refs,
tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
tie(c_grid_buf));
tie(c_grid_buf),
scatter_offsets,
scatter_weights
);
if constexpr(access_id < num_access - 1)
{

View File

@@ -100,14 +100,10 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
const StaticallyIndexedArray<Index, nSrc>& src_slice_origins,
const DstDescs& dst_descs,
const StaticallyIndexedArray<Index, nDst>& dst_slice_origins,
const ElementwiseOperation& element_op,
const StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
const StaticallyIndexedArray<float, scatter_num> &scatter_weights)
const ElementwiseOperation& element_op)
: src_coords_(MakeCoordinates(src_descs, src_slice_origins)),
dst_coords_(MakeCoordinates(dst_descs, dst_slice_origins)),
element_op_(element_op),
scatter_offsets_(scatter_offsets),
scatter_weights_(scatter_weights)
element_op_(element_op)
{
static_assert(SliceLengths::At(Number<SrcVectorDim>{}) % SrcScalarPerVector == 0,
"wrong! cannot evenly divide");
@@ -158,6 +154,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<SrcDescs::Size() == SrcBuffers::Size(), bool> = false>
__device__ void RunRead(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
StaticallyIndexedArray<float, scatter_num> &scatter_weights,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
// loop over space-filling curve
@@ -181,9 +178,9 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
static_assert(SrcScalarPerVectors{}[Number<ScatterWeightIdx>{}] == 1, "scatter weight dim, should only one vec");
constexpr auto iScatter = SrcSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
// if(threadIdx.x % 8 ==0 )
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights_(Number<iScatter>{}));
// printf("bid %d tid %d srcid %d sv %f\n", blockIdx.y, threadIdx.x, i.value, scatter_weights(Number<iScatter>{}));
static_for<0, SrcScalarPerVector, 1>{}(
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights_(Number<iScatter>{}); });
[&](auto j) { src_vectors(i).template AsType<float>()(j) = scatter_weights(Number<iScatter>{}); });
}
else if constexpr(SrcScalarPerVectors{}[i] == 1)
{
@@ -418,6 +415,7 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
enable_if_t<DstDescs::Size() == 1 && DstBuffers::Size() == 1, bool> = false>
__device__ void RunWrite(const DstDescs& dst_descs,
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
Number<ThreadScratchId> thread_scratch_id = Number<ThreadScratchId>{})
{
OOBCheck(thread_scratch_id);
@@ -430,13 +428,13 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
if constexpr (OutputScatter)
{
constexpr auto iScatter = DstSpaceFillingCurve::GetIndex(iAccess)(Number<ScatterDim>{});
scatter_offset = scatter_offsets_(Number<iScatter>{});
scatter_offset = scatter_offsets(Number<iScatter>{});
}
// copy data from buf_vectors into dst_bufs
static_for<0, nDst, 1>{}([&](auto i) {
using dst_vector_t = typename remove_cvref_t<decltype(dst_vectors[i])>::type;
auto dst_offset = scatter_offset + dst_coords_[i].GetOffset();
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();//hack felix, todo use coord
const bool is_dst_valid = dst_offset < dst_descs[i].GetElementSpaceSize();
// coordinate_has_valid_offset_assuming_visible_index_is_valid(dst_descs[i],
// dst_coords_[i]);
@@ -449,11 +447,11 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
dst_offset,
is_dst_valid,
dst_vectors[i].template AsType<dst_vector_t>()[I0]);
// if(1) {
// static_for<0, DstScalarPerVector, 1>{}([&](auto idx) {
// if(threadIdx.x%8 ==0 && blockIdx.x==0) {
// static_for<0, 1, 1>{}([&](auto idx) {
// using DstData = remove_cvref_t<tuple_element_t<0, DstDatas>>;
// using print_vec_t = typename vector_type<DstData, 1>::type;
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_coords_[i].GetOffset(), is_dst_valid,
// printf("tid %d off %d valid %d %f\n",threadIdx.x, dst_offset, is_dst_valid,
// type_convert<float>(dst_vectors[i].template AsType<print_vec_t>()[idx]));
// });
// }
@@ -509,10 +507,12 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
__device__ void Run(const SrcDescs& src_descs,
const SrcBuffers& src_bufs,
const DstDescs& dst_descs,
DstBuffers dst_bufs)
DstBuffers dst_bufs,
StaticallyIndexedArray<index_t, scatter_num> &scatter_offsets,
StaticallyIndexedArray<float, scatter_num> &scatter_weights)
{
RunRead(src_descs, src_bufs);
RunWrite(dst_descs, dst_bufs);
RunRead(src_descs, src_bufs, scatter_weights);
RunWrite(dst_descs, dst_bufs, scatter_offsets);
}
__device__ static constexpr auto GetSrcCoordinateResetStep()
@@ -683,8 +683,18 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
? dst_slice_origin_step_idx
: dst_slice_origin_step_idx + GetDstCoordinateResetStep();
auto adjusted_step_idx_scatter = [&]()
{
Index step_;
static_for<0, nDim, 1>{}([&](auto i) {
step_(i) = (i.value == ScatterDim && OutputScatter) ? 0 : adjusted_step_idx[Number<i>{}];
});
return step_;
}
();
// is it OK to construct a new step every time?
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx);
const auto adjusted_step = make_tensor_coordinate_step(dst_descs[iDst], adjusted_step_idx_scatter);
move_tensor_coordinate(dst_descs[iDst], dst_coords_(iDst), adjusted_step);
}
@@ -709,8 +719,6 @@ struct ThreadwiseTensorSliceTransfer_v7r3_scatter
SrcCoords src_coords_;
DstCoords dst_coords_;
const ElementwiseOperation element_op_;
StaticallyIndexedArray<index_t, scatter_num> scatter_offsets_;
StaticallyIndexedArray<float, scatter_num> scatter_weights_;
};
} // namespace ck