mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 13:41:24 +00:00
topk_softmax (#1592)
* topk_softmax * remove some file * fix atomix linear_offset * address various comment, and change sfc get_index api to static(tuple)
This commit is contained in:
7
include/ck_tile/ops/elementwise.hpp
Normal file
7
include/ck_tile/ops/elementwise.hpp
Normal file
@@ -0,0 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
1163
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
Normal file
1163
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
|
||||
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
|
||||
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
|
||||
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
|
||||
|
||||
@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
|
||||
if constexpr(i_k0 < k0_loops - 1)
|
||||
move_tile_window(k_dram_window, {0, kK0});
|
||||
|
||||
async_load_fence(k_dram_window.get_num_access());
|
||||
async_load_fence(k_dram_window.get_num_of_access());
|
||||
__builtin_amdgcn_s_barrier();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
gemm_0(s_acc,
|
||||
|
||||
@@ -4,9 +4,14 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <tuple>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
* TODO: block_tile_reduce_sync() currently has a limitation
|
||||
* Y dim must have at least one dim not been reduced
|
||||
*/
|
||||
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
|
||||
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
@@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
|
||||
});
|
||||
}
|
||||
|
||||
/*
|
||||
* this version is faster, using xor to do reduce, no need broadcast anymore
|
||||
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
|
||||
*/
|
||||
template <typename AccDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
using DstrEncodeDetail = typename DstrEncode::detail;
|
||||
|
||||
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
|
||||
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
|
||||
|
||||
constexpr index_t idim_p_lane = NDimP - 1;
|
||||
|
||||
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
// loop over thread data
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local = acc_tensor.get_thread_buffer()[i];
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
static_for<0, NDimR, 1>{}([&](auto idim_r) {
|
||||
// FIXME: nasty to use does_p_own_r_
|
||||
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
|
||||
{
|
||||
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
|
||||
|
||||
constexpr index_t lid_over_rid_derivative =
|
||||
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
|
||||
|
||||
static_assert(is_power_of_two_integer(r_length),
|
||||
"wrong! only support power of 2 reduction");
|
||||
|
||||
constexpr index_t nstage = integer_log2_floor(r_length);
|
||||
|
||||
// reduction sweep forward
|
||||
static_for<0, nstage, 1>{}([&](auto istage) {
|
||||
// xor
|
||||
index_t src_lane =
|
||||
__lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
acc_tensor.get_thread_buffer()(i) = v_local;
|
||||
});
|
||||
}
|
||||
|
||||
// FIXME: this is for 2D to 1D reduce only, need to support n-D
|
||||
template <typename AccDistributedTensor_,
|
||||
typename InDistributedTensor_,
|
||||
@@ -175,6 +239,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
|
||||
#endif
|
||||
}
|
||||
|
||||
/*
|
||||
* TODO: block_tile_reduce() currently has a limitation
|
||||
* Y dim must have at least one dim not been reduced
|
||||
*/
|
||||
template <typename AccDataType_,
|
||||
typename InDistributedTensor_,
|
||||
index_t... InReduceDims,
|
||||
@@ -208,4 +276,106 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
|
||||
return acc_tensor;
|
||||
}
|
||||
|
||||
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
|
||||
// this version only support in/acc/out datatypes are the same
|
||||
// this version will call thread/warp+sync in one function call
|
||||
//
|
||||
template <typename InDistributedTensor_>
|
||||
struct BlockReduce2D
|
||||
{
|
||||
using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
|
||||
using InDataType = typename InDistributedTensor::DataType;
|
||||
|
||||
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
|
||||
: t(t_), reduce_init(reduce_init_)
|
||||
{
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
|
||||
{
|
||||
using ReduceDim = sequence<1>; // hard coded
|
||||
constexpr auto acc_dstr =
|
||||
make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
|
||||
InDistributedTensor::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
ReduceDim{}));
|
||||
|
||||
return make_static_distributed_tensor<InDataType>(acc_dstr);
|
||||
}
|
||||
|
||||
// return number of pixels each lane need to reduce
|
||||
CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
|
||||
{
|
||||
constexpr auto spans = InDistributedTensor::get_distributed_spans();
|
||||
}
|
||||
|
||||
// Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
|
||||
// this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
|
||||
// internally
|
||||
// For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
|
||||
// element , and the first element is always ignored For simplicity, will always try from
|
||||
// right-to-left to find alone which Y dim to split
|
||||
template <typename ReduceFunc,
|
||||
typename ReduceSyncFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
|
||||
const ReduceSyncFunc& reduce_sync_func,
|
||||
ReducePacksPerXDim = {}) const
|
||||
{
|
||||
constexpr auto spans = InDistributedTensor::get_distributed_spans();
|
||||
|
||||
constexpr auto row_y_unpacks = [&]() {
|
||||
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
|
||||
constexpr auto row_y_size =
|
||||
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
|
||||
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
|
||||
|
||||
static_assert(row_y_size % row_y_packs == 0);
|
||||
|
||||
constexpr auto row_y_slice_size = row_y_size / row_y_packs;
|
||||
|
||||
constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
|
||||
constexpr auto unpacks = slice_info[number<1>{}];
|
||||
return unpacks;
|
||||
}();
|
||||
|
||||
auto acc_tensor = MakeDstBlockTile();
|
||||
|
||||
// in-thread reduction
|
||||
// FIXME: hard coded to be 2D to 1D reduction
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
|
||||
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
|
||||
|
||||
auto acc = acc_tensor[acc_dstr_idx];
|
||||
|
||||
sweep_tile_uspan(
|
||||
spans[number<1>{}],
|
||||
[&](auto... dstr_idx_i1) {
|
||||
acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
|
||||
},
|
||||
row_y_unpacks);
|
||||
|
||||
acc_tensor(acc_dstr_idx) = acc;
|
||||
});
|
||||
|
||||
// TODO: always use xor to do cross-lane reduce
|
||||
block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
|
||||
|
||||
return acc_tensor;
|
||||
}
|
||||
|
||||
template <typename ReduceFunc>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
|
||||
{
|
||||
return operator()(reduce_func, reduce_func);
|
||||
}
|
||||
|
||||
InDistributedTensor t;
|
||||
InDataType reduce_init;
|
||||
};
|
||||
|
||||
// deduction guide
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
8
include/ck_tile/ops/softmax.hpp
Normal file
8
include/ck_tile/ops/softmax.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
|
||||
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
81
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
Normal file
81
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
Normal file
@@ -0,0 +1,81 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/reduce.hpp"
|
||||
|
||||
#define _BLOCK_SOFTMAX_USE_UNPACK2 0
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
simple 2d softmax implementation, along row (dim=1)
|
||||
requirement:
|
||||
1). each row is within a warp
|
||||
2). data type must be a dword
|
||||
*/
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockSoftmax2D
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = typename Problem::DataType;
|
||||
|
||||
template <typename DistributedTensor, index_t dim = 1>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
|
||||
{
|
||||
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
||||
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
const auto f_max3 = [](auto e0, auto e1, auto e2) {
|
||||
float rtn;
|
||||
asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
|
||||
return rtn;
|
||||
};
|
||||
const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
|
||||
#endif
|
||||
|
||||
// compute row max
|
||||
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
|
||||
#else
|
||||
auto row_max = reduce_row_max(f_max);
|
||||
#endif
|
||||
sweep_tile<DistributedTensor>([&](auto idx) {
|
||||
constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
y(idx) = exp(x[idx] - row_max[row_id]);
|
||||
});
|
||||
|
||||
// compute row sum
|
||||
auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
|
||||
#if _BLOCK_SOFTMAX_USE_UNPACK2
|
||||
auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
|
||||
#else
|
||||
auto row_sum = reduce_row_sum(f_sum);
|
||||
#endif
|
||||
// reciprocal
|
||||
auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
|
||||
sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });
|
||||
|
||||
// scale
|
||||
sweep_tile<DistributedTensor>([&](auto idx) {
|
||||
constexpr auto row_id = make_tuple(idx[number<0>{}]);
|
||||
y(idx) = y(idx) * r(row_id);
|
||||
});
|
||||
}
|
||||
|
||||
template <typename DistributedTensor, index_t dim = 1>
|
||||
CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
|
||||
{
|
||||
auto y = DistributedTensor{}; // distributed tensor
|
||||
operator()(x, y, number<dim>{});
|
||||
return y;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,16 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename DataType_>
|
||||
struct BlockSoftmax2DProblem
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
8
include/ck_tile/ops/topk.hpp
Normal file
8
include/ck_tile/ops/topk.hpp
Normal file
@@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
|
||||
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
113
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
Normal file
113
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
Normal file
@@ -0,0 +1,113 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
simple 2d topk implementation, along row (dim=1)
|
||||
requirement:
|
||||
1). each row is within a warp
|
||||
*/
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
struct BlockTopkStream2D
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
|
||||
using DataType = typename Problem::DataType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
// TODO: if DataType is subdword, need pack into single dword to use argmax
|
||||
struct ArgmaxPacket
|
||||
{
|
||||
DataType arg;
|
||||
index_t value;
|
||||
};
|
||||
|
||||
template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
|
||||
CK_TILE_DEVICE void operator()(const DistributedTensor& x,
|
||||
const OutWindow& out_window,
|
||||
const IdxWindow& idx_window,
|
||||
index_t k,
|
||||
number<dim> = {})
|
||||
{
|
||||
OutWindow out_window_tmp = out_window;
|
||||
IdxWindow idx_window_tmp = idx_window;
|
||||
static_assert(
|
||||
std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
|
||||
std::is_same_v<typename DistributedTensor::DataType, DataType>);
|
||||
static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
|
||||
|
||||
DistributedTensor x_tmp = x;
|
||||
constexpr auto dst_dist = typename IdxWindow::TileDstr{};
|
||||
|
||||
// argmax for topk
|
||||
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
|
||||
return e0.arg > e1.arg ? e0 : e1;
|
||||
};
|
||||
|
||||
for(index_t i_k = 0; i_k < k; i_k++)
|
||||
{
|
||||
constexpr auto span_2d = DistributedTensor::get_distributed_spans();
|
||||
auto packet = [&]() {
|
||||
auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
|
||||
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket t;
|
||||
t.arg = x_tmp(i_j_idx); // !!! we reference x here
|
||||
t.value = tile_idx.at(number<1>{});
|
||||
tmp(i_j_idx) = t;
|
||||
});
|
||||
});
|
||||
return tmp;
|
||||
}();
|
||||
|
||||
auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
|
||||
auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
|
||||
block_tile_reduce_xor_sync(r, f_argmax);
|
||||
|
||||
auto o = make_static_distributed_tensor<DataType>(dst_dist);
|
||||
auto i = make_static_distributed_tensor<IndexType>(dst_dist);
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ArgmaxPacket tmp = r(i_j_idx);
|
||||
o(i_j_idx) = tmp.arg;
|
||||
i(i_j_idx) = tmp.value;
|
||||
});
|
||||
});
|
||||
|
||||
// update value
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
x.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
auto col_id = tile_idx.at(number<1>{});
|
||||
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
|
||||
: x_tmp(i_j_idx);
|
||||
});
|
||||
});
|
||||
|
||||
if(threadIdx.x % Problem::ColLanes == 0)
|
||||
{
|
||||
store_tile(out_window_tmp, o);
|
||||
store_tile(idx_window_tmp, i);
|
||||
}
|
||||
move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
|
||||
move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,22 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/*
|
||||
simple 2d topk implementation, along row (dim=1)
|
||||
requirement:
|
||||
1). each row is within a warp
|
||||
*/
|
||||
template <typename DataType_, typename IndexType_, index_t ColLanes_>
|
||||
struct BlockTopkStream2DProblem
|
||||
{
|
||||
using DataType = remove_cvref_t<DataType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
static constexpr index_t ColLanes = ColLanes_;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
10
include/ck_tile/ops/topk_softmax.hpp
Normal file
10
include/ck_tile/ops/topk_softmax.hpp
Normal file
@@ -0,0 +1,10 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
166
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
Normal file
166
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
Normal file
@@ -0,0 +1,166 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TopkSoftmaxHostArgs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
void* p_indices;
|
||||
index_t num_rows;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
struct TopkSoftmaxKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using InputType = typename Problem::InputType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
struct TopkSoftmaxKargs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
void* p_indices;
|
||||
index_t num_rows;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
};
|
||||
|
||||
using Kargs = TopkSoftmaxKargs;
|
||||
using Hargs = TopkSoftmaxHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
if constexpr(Problem::LaunchType > 0)
|
||||
{
|
||||
int num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return dim3(num_cu * Problem::LaunchType);
|
||||
}
|
||||
else
|
||||
{
|
||||
const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
|
||||
const int num_blocks =
|
||||
(num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
|
||||
return dim3(num_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_input = h.p_input;
|
||||
k.p_output = h.p_output;
|
||||
k.p_indices = h.p_indices;
|
||||
k.num_rows = h.num_rows;
|
||||
k.num_experts = h.num_experts;
|
||||
k.topk = h.topk;
|
||||
k.stride_input = h.stride_input;
|
||||
k.stride_output = h.stride_output;
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
|
||||
|
||||
if(block_row_id > kargs.num_rows)
|
||||
return;
|
||||
|
||||
index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
|
||||
index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
|
||||
index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);
|
||||
|
||||
const auto input_window = [&]() {
|
||||
const InputType* p_input =
|
||||
reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
|
||||
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input,
|
||||
make_tuple(num_rows_rem, kargs.num_experts),
|
||||
make_tuple(kargs.stride_input, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
auto view = pad_tensor_view(
|
||||
tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
|
||||
|
||||
return make_tile_window(
|
||||
view,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
{0, 0});
|
||||
}();
|
||||
|
||||
auto output_window = [&]() {
|
||||
WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_output,
|
||||
make_tuple(num_rows_rem, kargs.topk),
|
||||
make_tuple(kargs.stride_output, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
auto view =
|
||||
pad_tensor_view(tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
|
||||
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
|
||||
// 2. we loop over topk 1-1, no need padding
|
||||
return make_tile_window(
|
||||
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
|
||||
}();
|
||||
|
||||
auto indices_window = [&]() {
|
||||
IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_indices,
|
||||
make_tuple(num_rows_rem, kargs.topk),
|
||||
make_tuple(kargs.stride_output, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
auto view =
|
||||
pad_tensor_view(tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
|
||||
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
|
||||
// 2. we loop over topk 1-1, no need padding
|
||||
return make_tile_window(
|
||||
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
|
||||
}();
|
||||
|
||||
Pipeline{}(input_window,
|
||||
output_window,
|
||||
indices_window,
|
||||
kargs.num_rows,
|
||||
kargs.num_experts,
|
||||
kargs.topk,
|
||||
block_row_id);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
|
||||
struct TopkSoftmaxWarpPerRowPipeline
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
template <typename InputWindow, typename OutputWindow, typename IndexWindow>
|
||||
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
|
||||
OutputWindow& out_window,
|
||||
IndexWindow& idx_window,
|
||||
index_t rows,
|
||||
index_t experts,
|
||||
index_t k,
|
||||
index_t block_row_id)
|
||||
{
|
||||
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
auto inp_win = make_tile_window_linear_raw(
|
||||
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
|
||||
#else
|
||||
auto inp_win = make_tile_window_linear(
|
||||
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
|
||||
#endif
|
||||
auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
|
||||
out_window.get_window_lengths(),
|
||||
out_window.get_window_origin(),
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
|
||||
idx_window.get_window_lengths(),
|
||||
idx_window.get_window_origin(),
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
auto softmax = Policy::template GetSoftmax<Problem>();
|
||||
auto topk = Policy::template GetTopk<Problem>();
|
||||
|
||||
const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
|
||||
|
||||
while(1)
|
||||
{
|
||||
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto x =
|
||||
load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
|
||||
buffer_load_fence(number<0>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#else
|
||||
auto x = load_tile(inp_win);
|
||||
#endif
|
||||
// cast and pad input data
|
||||
auto w = [&]() {
|
||||
#if 0
|
||||
auto w_ = cast_tile<WeightType>(x);
|
||||
|
||||
constexpr auto span_2d = decltype(w_)::get_distributed_spans();
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
w_.get_tile_distribution(), i_j_idx);
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
// set to -INF if OOB so that later softmax can work properly
|
||||
w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
|
||||
: w_(i_j_idx);
|
||||
});
|
||||
});
|
||||
return w_;
|
||||
#else
|
||||
auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
|
||||
auto w_f = [&](auto idx) {
|
||||
w_(idx) = type_convert<WeightType>(x(idx));
|
||||
const auto x_indices =
|
||||
get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
w_(idx) =
|
||||
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
|
||||
};
|
||||
tile_sweeper ts{w_, w_f};
|
||||
ts();
|
||||
return w_;
|
||||
#endif
|
||||
}();
|
||||
|
||||
// softmax
|
||||
auto y = softmax(w);
|
||||
|
||||
topk(y, out_win, idx_win, k);
|
||||
|
||||
// check exit
|
||||
if constexpr(Problem::LaunchType == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_row_id += grid_rows_per_loop;
|
||||
if(block_row_id >= rows)
|
||||
break;
|
||||
}
|
||||
|
||||
move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
|
||||
move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
|
||||
move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/softmax.hpp"
|
||||
#include "ck_tile/ops/topk.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TopkSoftmaxWarpPerRowPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
// TODO: Y dim must have one dim that is not reduced
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<Problem::IssuesPerCol,
|
||||
Problem::WarpsPerBlock,
|
||||
Problem::RowsPerWarpPerColIssue>,
|
||||
sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
|
||||
tuple<sequence<Problem::IssuesPerCol,
|
||||
Problem::WarpsPerBlock,
|
||||
Problem::RowsPerWarpPerColIssue>,
|
||||
sequence<1>>, // each row write out single element
|
||||
tuple<sequence<1>, sequence<1, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
|
||||
{
|
||||
using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
|
||||
return BlockSoftmax2D<softmax_problem>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
|
||||
{
|
||||
using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
|
||||
typename Problem::IndexType,
|
||||
Problem::LanesPerRow>;
|
||||
// Note: replicate is LanesPerRow
|
||||
return BlockTopkStream2D<topk_problem>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InputType_,
|
||||
typename WeightType_,
|
||||
typename IndexType_,
|
||||
index_t Experts_,
|
||||
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
|
||||
index_t BytesPerIssue_ = sizeof(InputType_),
|
||||
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
|
||||
index_t BlockSize_ = 256>
|
||||
struct TopkSoftmaxWarpPerRowProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using InputType = remove_cvref_t<InputType_>;
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t LaunchType = LaunchType_;
|
||||
static constexpr index_t Experts = Experts_;
|
||||
static constexpr index_t BytesPerIssue = BytesPerIssue_;
|
||||
static constexpr index_t IssuesPerCol = IssuesPerCol_;
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
|
||||
static_assert(BytesPerIssue % sizeof(InputType) == 0);
|
||||
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
|
||||
static_assert(Experts % VectorSize == 0);
|
||||
static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
|
||||
static_assert(WarpSize % LanesPerRow == 0);
|
||||
static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
|
||||
static constexpr index_t RowsPerWarp = IssuesPerCol * RowsPerWarpPerColIssue;
|
||||
static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
|
||||
|
||||
static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
|
||||
static constexpr index_t RowsPerBlock = RowsPerWarp * WarpsPerBlock;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user