mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
add tag for gather index
This commit is contained in:
@@ -57,7 +57,7 @@ struct indexing_adaptor_onshot_cached
|
||||
return ck_tile::is_known_at_compile_time<IndexingType>::value;
|
||||
}
|
||||
};
|
||||
|
||||
#define Using_Gather 1
|
||||
template <typename IndexingType>
|
||||
struct indexing_adaptor
|
||||
{
|
||||
@@ -65,8 +65,10 @@ struct indexing_adaptor
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor() = default;
|
||||
CK_TILE_HOST_DEVICE constexpr indexing_adaptor(const IndexingType* idx) : cached_idx_(idx) {}
|
||||
const IndexingType* cached_idx_;
|
||||
#if Using_Gather
|
||||
mutable index_t pre_up_index_ = 0;
|
||||
mutable index_t pre_low_index_ = 0;
|
||||
#endif
|
||||
|
||||
template <typename LowIdx, typename UpIdx>
|
||||
CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
|
||||
@@ -76,7 +78,7 @@ struct indexing_adaptor
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
idx_low(number<0>{}) = *(cached_idx_ + idx_up[number<0>{}]);
|
||||
|
||||
#if Using_Gather
|
||||
pre_up_index_ = idx_up[number<0>{}];
|
||||
pre_low_index_ = idx_low(number<0>{});
|
||||
#if 0
|
||||
@@ -84,6 +86,7 @@ struct indexing_adaptor
|
||||
{
|
||||
printf("\n first index from %d to %d \n", idx_up[number<0>{}], idx_low(number<0>{}));
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -97,7 +100,9 @@ struct indexing_adaptor
|
||||
static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 && LowIdx::size() == 1 &&
|
||||
UpIdx::size() == 1,
|
||||
"wrong! inconsistent # of dimension");
|
||||
|
||||
#if !Using_Gather
|
||||
idx_diff_low(number<0>{}) = idx_diff_up[number<0>{}];
|
||||
#else
|
||||
int up_index = idx_diff_up[number<0>{}] + pre_up_index_;
|
||||
int low_index = *(cached_idx_ + up_index);
|
||||
idx_diff_low(number<0>{}) = low_index - pre_low_index_;
|
||||
@@ -113,6 +118,7 @@ struct indexing_adaptor
|
||||
idx_diff_up[number<0>{}],
|
||||
idx_diff_low(number<0>{}));
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
|
||||
// pass the diff to lower, but not changing the actually index
|
||||
|
||||
Reference in New Issue
Block a user