topk_softmax (#1592)

* topk_softmax

* remove some file

* fix atomix linear_offset

* address various comment, and change sfc get_index api to static(tuple)
This commit is contained in:
carlushuang
2024-10-26 23:52:49 +08:00
committed by GitHub
parent 31bf253aeb
commit b098b71b05
41 changed files with 5603 additions and 226 deletions

View File

@@ -0,0 +1,7 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

File diff suppressed because it is too large Load Diff

View File

@@ -334,7 +334,7 @@ struct BlockFmhaPipelineQRKSVSAsync
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
buffer_load_fence(k_dram_window.get_num_of_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x if use q element func will have scratch on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
@@ -359,7 +359,7 @@ struct BlockFmhaPipelineQRKSVSAsync
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
async_load_fence(k_dram_window.get_num_of_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,

View File

@@ -4,9 +4,14 @@
#pragma once
#include "ck_tile/core.hpp"
#include <tuple>
namespace ck_tile {
/*
* TODO: block_tile_reduce_sync() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
// synchronize reduce result (cross lane reduction and broadcast on replicated dimension)
template <typename AccDistributedTensor_, typename ReduceFunc, bool WithBroadcast = true>
CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
@@ -104,6 +109,65 @@ CK_TILE_DEVICE void block_tile_reduce_sync(AccDistributedTensor_& acc_tensor,
});
}
/*
* this version is faster, using xor to do reduce, no need broadcast anymore
* TODO: the limitation is to-be-reduced P dim can only mapping to one R dim?
*/
template <typename AccDistributedTensor_, typename ReduceFunc>
CK_TILE_DEVICE void block_tile_reduce_xor_sync(AccDistributedTensor_& acc_tensor,
const ReduceFunc& reduce_func)
{
using Dstr = typename AccDistributedTensor_::StaticTileDistribution;
using DstrEncode = typename Dstr::DstrEncode;
using DstrEncodeDetail = typename DstrEncode::detail;
constexpr index_t NDimP = Dstr::get_num_of_dimension_p();
constexpr index_t NDimR = Dstr::get_num_of_dimension_r();
constexpr index_t idim_p_lane = NDimP - 1;
constexpr index_t thread_buf_size = AccDistributedTensor_::get_thread_buffer_size();
// loop over thread data
static_for<0, thread_buf_size, 1>{}([&](auto i) {
auto v_local = acc_tensor.get_thread_buffer()[i];
// cross-lane reduce for replication
// only reduce on R dimension correspond to lane
// (lane id maps to this R dimension)
static_for<0, NDimR, 1>{}([&](auto idim_r) {
// FIXME: nasty to use does_p_own_r_
if constexpr(DstrEncodeDetail::does_p_own_r_[idim_p_lane][idim_r])
{
constexpr index_t r_length = DstrEncode::rs_lengths_[idim_r];
constexpr index_t lid_over_rid_derivative =
DstrEncodeDetail::ps_over_rs_derivative_[idim_p_lane][idim_r];
static_assert(is_power_of_two_integer(r_length),
"wrong! only support power of 2 reduction");
constexpr index_t nstage = integer_log2_floor(r_length);
// reduction sweep forward
static_for<0, nstage, 1>{}([&](auto istage) {
// xor
index_t src_lane =
__lane_id() ^ (number<lid_over_rid_derivative << istage.value>{}.value);
// pull data from remote lane
const auto v_remote = warp_shuffle(v_local, src_lane);
// reduce
v_local = reduce_func(v_local, v_remote);
});
}
});
acc_tensor.get_thread_buffer()(i) = v_local;
});
}
// FIXME: this is for 2D to 1D reduce only, need to support n-D
template <typename AccDistributedTensor_,
typename InDistributedTensor_,
@@ -175,6 +239,10 @@ CK_TILE_DEVICE void block_tile_reduce(AccDistributedTensor_& acc_tensor,
#endif
}
/*
* TODO: block_tile_reduce() currently has a limitation
* Y dim must have at least one dim not been reduced
*/
template <typename AccDataType_,
typename InDistributedTensor_,
index_t... InReduceDims,
@@ -208,4 +276,106 @@ CK_TILE_DEVICE auto block_tile_reduce(const InDistributedTensor_& in_tensor,
return acc_tensor;
}
// this version only support 2D->1D reduce (reduce-dim=seq<0, 1>)
// this version only support in/acc/out datatypes are the same
// this version will call thread/warp+sync in one function call
//
template <typename InDistributedTensor_>
struct BlockReduce2D
{
using InDistributedTensor = remove_cvref_t<InDistributedTensor_>;
using InDataType = typename InDistributedTensor::DataType;
CK_TILE_HOST_DEVICE BlockReduce2D(const InDistributedTensor& t_, const InDataType& reduce_init_)
: t(t_), reduce_init(reduce_init_)
{
}
CK_TILE_HOST_DEVICE constexpr auto MakeDstBlockTile() const
{
using ReduceDim = sequence<1>; // hard coded
constexpr auto acc_dstr =
make_static_tile_distribution(ck_tile::detail::make_reduce_tile_distribution_encoding(
InDistributedTensor::get_tile_distribution()
.get_static_tile_distribution_encoding(),
ReduceDim{}));
return make_static_distributed_tensor<InDataType>(acc_dstr);
}
// return number of pixels each lane need to reduce
CK_TILE_HOST_DEVICE constexpr auto get_reduce_length_y() const
{
constexpr auto spans = InDistributedTensor::get_distributed_spans();
}
// Here ReducePacksPerXDim is not the same meaning as that in static_uford/sweep_tile_uspan
// this is number of packs along the X-dim. We need to compute the Unpacks along the Y dim
// internally
// For simplicity, we just support along the row dimension, ReducePacksPerXDim is always 2
// element , and the first element is always ignored For simplicity, will always try from
// right-to-left to find alone which Y dim to split
template <typename ReduceFunc,
typename ReduceSyncFunc,
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func,
const ReduceSyncFunc& reduce_sync_func,
ReducePacksPerXDim = {}) const
{
constexpr auto spans = InDistributedTensor::get_distributed_spans();
constexpr auto row_y_unpacks = [&]() {
constexpr auto row_y_lengths = typename decltype(spans[number<1>{}])::Impl{};
constexpr auto row_y_size =
reduce_on_sequence(row_y_lengths, multiplies{}, number<1>{});
constexpr auto row_y_packs = ReducePacksPerXDim{}.at(number<1>{});
static_assert(row_y_size % row_y_packs == 0);
constexpr auto row_y_slice_size = row_y_size / row_y_packs;
constexpr auto slice_info = slice_sequence(row_y_lengths, number<row_y_slice_size>{});
constexpr auto unpacks = slice_info[number<1>{}];
return unpacks;
}();
auto acc_tensor = MakeDstBlockTile();
// in-thread reduction
// FIXME: hard coded to be 2D to 1D reduction
sweep_tile_span(spans[number<0>{}], [&](auto dstr_idx_i0) {
constexpr auto acc_dstr_idx = make_tuple(dstr_idx_i0);
auto acc = acc_tensor[acc_dstr_idx];
sweep_tile_uspan(
spans[number<1>{}],
[&](auto... dstr_idx_i1) {
acc = reduce_func(acc, t[make_tuple(dstr_idx_i0, dstr_idx_i1)]...);
},
row_y_unpacks);
acc_tensor(acc_dstr_idx) = acc;
});
// TODO: always use xor to do cross-lane reduce
block_tile_reduce_xor_sync(acc_tensor, reduce_sync_func);
return acc_tensor;
}
template <typename ReduceFunc>
CK_TILE_HOST_DEVICE auto operator()(const ReduceFunc& reduce_func) const
{
return operator()(reduce_func, reduce_func);
}
InDistributedTensor t;
InDataType reduce_init;
};
// deduction guide
template <typename T>
CK_TILE_HOST_DEVICE_EXTERN BlockReduce2D(const T&, const typename T::DataType&)->BlockReduce2D<T>;
} // namespace ck_tile

View File

@@ -0,0 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/softmax/block/block_softmax_2d.hpp"
#include "ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,81 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
#define _BLOCK_SOFTMAX_USE_UNPACK2 0
namespace ck_tile {
/*
simple 2d softmax implementation, along row (dim=1)
requirement:
1). each row is within a warp
2). data type must be a dword
*/
template <typename Problem_, typename Policy_ = void>
struct BlockSoftmax2D
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = typename Problem::DataType;
template <typename DistributedTensor, index_t dim = 1>
CK_TILE_DEVICE void
operator()(const DistributedTensor& x, DistributedTensor& y, number<dim> = {})
{
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
#if _BLOCK_SOFTMAX_USE_UNPACK2
const auto f_max3 = [](auto e0, auto e1, auto e2) {
float rtn;
asm volatile("v_max3_f32 %0, %1, %2, %3" : "=v"(rtn) : "v"(e0), "v"(e1), "v"(e2));
return rtn;
};
const auto f_sum3 = [](auto e0, auto e1, auto e2) { return e0 + e1 + e2; };
#endif
// compute row max
auto reduce_row_max = BlockReduce2D{x, -numeric<DataType>::infinity()};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_max = reduce_row_max(f_max3, f_max, sequence<1, 2>{});
#else
auto row_max = reduce_row_max(f_max);
#endif
sweep_tile<DistributedTensor>([&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
y(idx) = exp(x[idx] - row_max[row_id]);
});
// compute row sum
auto reduce_row_sum = BlockReduce2D<decltype(y)>{y, DataType{0}};
#if _BLOCK_SOFTMAX_USE_UNPACK2
auto row_sum = reduce_row_sum(f_sum3, f_sum, sequence<1, 2>{});
#else
auto row_sum = reduce_row_sum(f_sum);
#endif
// reciprocal
auto r = make_static_distributed_tensor<DataType>(row_sum.get_tile_distribution());
sweep_tile(row_sum, [&](auto idx) { r(idx) = DataType{1} / row_sum(idx); });
// scale
sweep_tile<DistributedTensor>([&](auto idx) {
constexpr auto row_id = make_tuple(idx[number<0>{}]);
y(idx) = y(idx) * r(row_id);
});
}
template <typename DistributedTensor, index_t dim = 1>
CK_TILE_DEVICE decltype(auto) operator()(const DistributedTensor& x, number<dim> = {})
{
auto y = DistributedTensor{}; // distributed tensor
operator()(x, y, number<dim>{});
return y;
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,16 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename DataType_>
struct BlockSoftmax2DProblem
{
using DataType = remove_cvref_t<DataType_>;
};
} // namespace ck_tile

View File

@@ -0,0 +1,8 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,113 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename Problem_, typename Policy_ = void>
struct BlockTopkStream2D
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using DataType = typename Problem::DataType;
using IndexType = typename Problem::IndexType;
// TODO: if DataType is subdword, need pack into single dword to use argmax
struct ArgmaxPacket
{
DataType arg;
index_t value;
};
template <typename DistributedTensor, typename OutWindow, typename IdxWindow, index_t dim = 1>
CK_TILE_DEVICE void operator()(const DistributedTensor& x,
const OutWindow& out_window,
const IdxWindow& idx_window,
index_t k,
number<dim> = {})
{
OutWindow out_window_tmp = out_window;
IdxWindow idx_window_tmp = idx_window;
static_assert(
std::is_same_v<typename DistributedTensor::DataType, typename OutWindow::DataType> &&
std::is_same_v<typename DistributedTensor::DataType, DataType>);
static_assert(std::is_same_v<typename IdxWindow::DataType, IndexType>);
DistributedTensor x_tmp = x;
constexpr auto dst_dist = typename IdxWindow::TileDstr{};
// argmax for topk
const auto f_argmax = [](ArgmaxPacket e0, ArgmaxPacket e1) {
return e0.arg > e1.arg ? e0 : e1;
};
for(index_t i_k = 0; i_k < k; i_k++)
{
constexpr auto span_2d = DistributedTensor::get_distributed_spans();
auto packet = [&]() {
auto tmp = make_static_distributed_tensor<ArgmaxPacket>(x.get_tile_distribution());
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
tmp.get_tile_distribution(), make_tuple(idx0, idx1));
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket t;
t.arg = x_tmp(i_j_idx); // !!! we reference x here
t.value = tile_idx.at(number<1>{});
tmp(i_j_idx) = t;
});
});
return tmp;
}();
auto argmax_init = ArgmaxPacket{-numeric<DataType>::infinity(), 0};
auto r = block_tile_reduce<ArgmaxPacket>(packet, sequence<1>{}, f_argmax, argmax_init);
block_tile_reduce_xor_sync(r, f_argmax);
auto o = make_static_distributed_tensor<DataType>(dst_dist);
auto i = make_static_distributed_tensor<IndexType>(dst_dist);
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
ArgmaxPacket tmp = r(i_j_idx);
o(i_j_idx) = tmp.arg;
i(i_j_idx) = tmp.value;
});
});
// update value
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
x.get_tile_distribution(), make_tuple(idx0, idx1));
auto col_id = tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
x_tmp(i_j_idx) = (col_id == r(i_j_idx).value) ? -numeric<DataType>::infinity()
: x_tmp(i_j_idx);
});
});
if(threadIdx.x % Problem::ColLanes == 0)
{
store_tile(out_window_tmp, o);
store_tile(idx_window_tmp, i);
}
move_tile_window(out_window_tmp, {number<0>{}, number<1>{}});
move_tile_window(idx_window_tmp, {number<0>{}, number<1>{}});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,22 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template <typename DataType_, typename IndexType_, index_t ColLanes_>
struct BlockTopkStream2DProblem
{
using DataType = remove_cvref_t<DataType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t ColLanes = ColLanes_;
};
} // namespace ck_tile

View File

@@ -0,0 +1,10 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

View File

@@ -0,0 +1,166 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
struct TopkSoftmaxHostArgs
{
const void* p_input;
void* p_output;
void* p_indices;
index_t num_rows;
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
};
template <typename Pipeline_>
struct TopkSoftmaxKernel
{
using Pipeline = remove_cvref_t<Pipeline_>;
using Problem = remove_cvref_t<typename Pipeline::Problem>;
using InputType = typename Problem::InputType;
using WeightType = typename Problem::WeightType;
using IndexType = typename Problem::IndexType;
struct TopkSoftmaxKargs
{
const void* p_input;
void* p_output;
void* p_indices;
index_t num_rows;
index_t num_experts;
index_t topk;
index_t stride_input; // row stride for input, at least experts
index_t stride_output; // row stride for output/indices, at least tpok
};
using Kargs = TopkSoftmaxKargs;
using Hargs = TopkSoftmaxHostArgs;
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
{
if constexpr(Problem::LaunchType > 0)
{
int num_cu = [&]() {
hipDeviceProp_t dev_prop;
hipDevice_t dev;
HIP_CHECK_ERROR(hipGetDevice(&dev));
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
return dev_prop.multiProcessorCount;
}();
return dim3(num_cu * Problem::LaunchType);
}
else
{
const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
const int num_blocks =
(num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
return dim3(num_blocks);
}
}
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
{
Kargs k;
k.p_input = h.p_input;
k.p_output = h.p_output;
k.p_indices = h.p_indices;
k.num_rows = h.num_rows;
k.num_experts = h.num_experts;
k.topk = h.topk;
k.stride_input = h.stride_input;
k.stride_output = h.stride_output;
return k;
}
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
if(block_row_id > kargs.num_rows)
return;
index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);
const auto input_window = [&]() {
const InputType* p_input =
reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
auto tmp = make_naive_tensor_view<address_space_enum::global>(
p_input,
make_tuple(num_rows_rem, kargs.num_experts),
make_tuple(kargs.stride_input, 1),
number<Problem::VectorSize>{},
number<1>{});
auto view = pad_tensor_view(
tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
return make_tile_window(
view,
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
{0, 0});
}();
auto output_window = [&]() {
WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
auto tmp = make_naive_tensor_view<address_space_enum::global>(
p_output,
make_tuple(num_rows_rem, kargs.topk),
make_tuple(kargs.stride_output, 1),
number<Problem::VectorSize>{},
number<1>{});
auto view =
pad_tensor_view(tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return make_tile_window(
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
}();
auto indices_window = [&]() {
IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
auto tmp = make_naive_tensor_view<address_space_enum::global>(
p_indices,
make_tuple(num_rows_rem, kargs.topk),
make_tuple(kargs.stride_output, 1),
number<Problem::VectorSize>{},
number<1>{});
auto view =
pad_tensor_view(tmp,
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return make_tile_window(
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
}();
Pipeline{}(input_window,
output_window,
indices_window,
kargs.num_rows,
kargs.num_experts,
kargs.topk,
block_row_id);
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace ck_tile {
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
struct TopkSoftmaxWarpPerRowPipeline
{
// TODO: this kernel only support warp per row
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using WeightType = typename Problem::WeightType;
template <typename InputWindow, typename OutputWindow, typename IndexWindow>
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
OutputWindow& out_window,
IndexWindow& idx_window,
index_t rows,
index_t experts,
index_t k,
index_t block_row_id)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
auto inp_win = make_tile_window_linear_raw(
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#else
auto inp_win = make_tile_window_linear(
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
#endif
auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
out_window.get_window_lengths(),
out_window.get_window_origin(),
Policy::template MakeOutputDistribution<Problem>());
auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
idx_window.get_window_lengths(),
idx_window.get_window_origin(),
Policy::template MakeOutputDistribution<Problem>());
auto softmax = Policy::template GetSoftmax<Problem>();
auto topk = Policy::template GetTopk<Problem>();
const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
while(1)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier(0);
auto x =
load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
buffer_load_fence(number<0>{});
__builtin_amdgcn_sched_barrier(0);
#else
auto x = load_tile(inp_win);
#endif
// cast and pad input data
auto w = [&]() {
#if 0
auto w_ = cast_tile<WeightType>(x);
constexpr auto span_2d = decltype(w_)::get_distributed_spans();
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
w_.get_tile_distribution(), i_j_idx);
const auto current_expert = x_indices.at(number<1>{});
// set to -INF if OOB so that later softmax can work properly
w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
: w_(i_j_idx);
});
});
return w_;
#else
auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
auto w_f = [&](auto idx) {
w_(idx) = type_convert<WeightType>(x(idx));
const auto x_indices =
get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
const auto current_expert = x_indices.at(number<1>{});
w_(idx) =
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
};
tile_sweeper ts{w_, w_f};
ts();
return w_;
#endif
}();
// softmax
auto y = softmax(w);
topk(y, out_win, idx_win, k);
// check exit
if constexpr(Problem::LaunchType == 0)
{
break;
}
else
{
block_row_id += grid_rows_per_loop;
if(block_row_id >= rows)
break;
}
move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
}
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,63 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {
struct TopkSoftmaxWarpPerRowPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
{
// TODO: Y dim must have one dim that is not reduced
return make_static_tile_distribution(
tile_distribution_encoding<
sequence<1>,
tuple<sequence<Problem::IssuesPerCol,
Problem::WarpsPerBlock,
Problem::RowsPerWarpPerColIssue>,
sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 1>>,
sequence<1, 2, 2>,
sequence<0, 0, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
{
return make_static_tile_distribution(
tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
tuple<sequence<Problem::IssuesPerCol,
Problem::WarpsPerBlock,
Problem::RowsPerWarpPerColIssue>,
sequence<1>>, // each row write out single element
tuple<sequence<1>, sequence<1, 0>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 0>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
{
using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
return BlockSoftmax2D<softmax_problem>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
{
using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
typename Problem::IndexType,
Problem::LanesPerRow>;
// Note: replicate is LanesPerRow
return BlockTopkStream2D<topk_problem>{};
}
};
} // namespace ck_tile

View File

@@ -0,0 +1,46 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename InputType_,
typename WeightType_,
typename IndexType_,
index_t Experts_,
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
index_t BytesPerIssue_ = sizeof(InputType_),
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
index_t BlockSize_ = 256>
struct TopkSoftmaxWarpPerRowProblem
{
// TODO: this kernel only support warp per row
using InputType = remove_cvref_t<InputType_>;
using WeightType = remove_cvref_t<WeightType_>;
using IndexType = remove_cvref_t<IndexType_>;
static constexpr index_t LaunchType = LaunchType_;
static constexpr index_t Experts = Experts_;
static constexpr index_t BytesPerIssue = BytesPerIssue_;
static constexpr index_t IssuesPerCol = IssuesPerCol_;
static constexpr index_t BlockSize = BlockSize_;
static constexpr index_t WarpSize = get_warp_size();
static_assert(BytesPerIssue % sizeof(InputType) == 0);
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
static_assert(Experts % VectorSize == 0);
static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
static_assert(WarpSize % LanesPerRow == 0);
static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
static constexpr index_t RowsPerWarp = IssuesPerCol * RowsPerWarpPerColIssue;
static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
static constexpr index_t RowsPerBlock = RowsPerWarp * WarpsPerBlock;
};
} // namespace ck_tile