mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
topk_softmax (#1592)
* topk_softmax * remove some files * fix atomic linear_offset * address various comments, and change sfc get_index api to static(tuple)
This commit is contained in:
@@ -4,9 +4,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <tuple>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
* TODO: block_tile_reduce_sync() currently has a limitation
|
||||
* Y dim must have at least one dim not been reduced
|
||||
*/
|
||||
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
|
||||
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
@@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
* this version is faster, using xor to do reduce, no need broadcast anymore
|
||||
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
|
||||
*/
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// loop over thread data
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local = acc_tensor.get_thread_buffer()[i];
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// xor
|
||||
index_t src_lane =
|
||||
__lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
acc_tensor.get_thread_buffer()(i) = v_local;
|
||||
});
|
||||
}
|
||||
|
||||
// FIXME: this is for 2D to 1D reduce only, need to support n-D
|
||||
template <typename AccDistributedTensor_,
|
||||
typename InDistributedTensor_,
|
||||
@@ -175,6 +239,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* TODO: block_tile_reduce() currently has a limitation
|
||||
* Y dim must have at least one dim not been reduced
|
||||
*/
|
||||
template <typename AccDataType_,
|
||||
typename InDistributedTensor_,
|
||||
index_t... InReduceDims,
|
||||
@@ -208,4 +276,106 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
|
||||
return acc_tensor;
|
||||
}
|
||||
|
||||
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
|
||||
// this version only support in/acc/out datatypes are the same
|
||||
// this version will call thread/warp+sync in one function call
|
||||
//
|
||||
template <typename InDistributedTensor_>
|
||||
struct BlockReduce2D
|
||||
{
|
||||
using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
|
||||
using InDataType = typename InDistributedTensor::DataType;
|
||||
|
||||
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
|
||||
: t(t_), reduce_init(reduce_init_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
|
||||
{
|
||||
using ReduceDim = sequence<1>; // hard coded
|
||||
constexpr auto acc_dstr =
|
||||
make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
|
||||
InDistributedTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
ReduceDim{}));
|
||||
|
||||
return make_static_distributed_tensor<InDataType>(acc_dstr);
|
||||
}
|
||||
|
||||
// return number of pixels each lane need to reduce
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
|
||||
{
|
||||
constexpr auto spans = InDistributedTensor::get_distributed_spans();
|
||||
}
|
||||
|
||||
// Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
|
||||
// this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
|
||||
// internally
|
||||
// For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
|
||||
// element , and the first element is always ignored For simplicity, will always try from
|
||||
// right-to-left to find alone which Y dim to split
|
||||
template <typename ReduceFunc,
|
||||
typename ReduceSyncFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
|
||||
const ReduceSyncFunc& reduce_sync_func,
|
||||
ReducePacksPerXDim = {}) const
|
||||
{
|
||||
constexpr auto spans = InDistributedTensor::get_distributed_spans();
|
||||
|
||||
constexpr auto row_y_unpacks = [&]() {
|
||||
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
|
||||
constexpr auto row_y_size =
|
||||
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
|
||||
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
|
||||
|
||||
static_assert(row_y_size % row_y_packs == 0);
|
||||
|
||||
constexpr auto row_y_slice_size = row_y_size / row_y_packs;
|
||||
|
||||
constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
|
||||
constexpr auto unpacks = slice_info[number<1>{}];
|
||||
return unpacks;
|
||||
}();
|
||||
|
||||
auto acc_tensor = MakeDstBlockTile();
|
||||
|
||||
// in-thread reduction
|
||||
// FIXME: hard coded to be 2D to 1D reduction
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
|
||||
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto acc = acc_tensor[acc_dstr_idx];
|
||||
|
||||
sweep_tile_uspan(
|
||||
spans[number<1>{}],
|
||||
[&](auto... dstr_idx_i1) {
|
||||
acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
|
||||
},
|
||||
row_y_unpacks);
|
||||
|
||||
acc_tensor(acc_dstr_idx) = acc;
|
||||
});
|
||||
|
||||
// TODO: always use xor to do cross-lane reduce
|
||||
block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
|
||||
|
||||
return acc_tensor;
|
||||
}
|
||||
|
||||
template <typename ReduceFunc>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
|
||||
{
|
||||
return operator()(reduce_func, reduce_func);
|
||||
}
|
||||
|
||||
InDistributedTensor t;
|
||||
InDataType reduce_init;
|
||||
};
|
||||
|
||||
// deduction guide: BlockReduce2D(tile, init) deduces BlockReduce2D<decltype(tile)>
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user