mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
topk_softmax (#1592)
* topk_softmax * remove some file * fix atomix linear_offset * address various comment, and change sfc get_index api to static(tuple)
This commit is contained in:
166
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
Normal file
166
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
Normal file
@@ -0,0 +1,166 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/hip_check_error.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TopkSoftmaxHostArgs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
void* p_indices;
|
||||
index_t num_rows;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
};
|
||||
|
||||
template <typename Pipeline_>
|
||||
struct TopkSoftmaxKernel
|
||||
{
|
||||
using Pipeline = remove_cvref_t<Pipeline_>;
|
||||
using Problem = remove_cvref_t<typename Pipeline::Problem>;
|
||||
|
||||
using InputType = typename Problem::InputType;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
using IndexType = typename Problem::IndexType;
|
||||
|
||||
struct TopkSoftmaxKargs
|
||||
{
|
||||
const void* p_input;
|
||||
void* p_output;
|
||||
void* p_indices;
|
||||
index_t num_rows;
|
||||
index_t num_experts;
|
||||
index_t topk;
|
||||
index_t stride_input; // row stride for input, at least experts
|
||||
index_t stride_output; // row stride for output/indices, at least tpok
|
||||
};
|
||||
|
||||
using Kargs = TopkSoftmaxKargs;
|
||||
using Hargs = TopkSoftmaxHostArgs;
|
||||
|
||||
CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
|
||||
{
|
||||
if constexpr(Problem::LaunchType > 0)
|
||||
{
|
||||
int num_cu = [&]() {
|
||||
hipDeviceProp_t dev_prop;
|
||||
hipDevice_t dev;
|
||||
HIP_CHECK_ERROR(hipGetDevice(&dev));
|
||||
HIP_CHECK_ERROR(hipGetDeviceProperties(&dev_prop, dev));
|
||||
return dev_prop.multiProcessorCount;
|
||||
}();
|
||||
return dim3(num_cu * Problem::LaunchType);
|
||||
}
|
||||
else
|
||||
{
|
||||
const int num_warps = (h.num_rows + Problem::RowsPerWarp - 1) / Problem::RowsPerWarp;
|
||||
const int num_blocks =
|
||||
(num_warps + Problem::WarpsPerBlock - 1) / Problem::WarpsPerBlock;
|
||||
return dim3(num_blocks);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
|
||||
{
|
||||
Kargs k;
|
||||
k.p_input = h.p_input;
|
||||
k.p_output = h.p_output;
|
||||
k.p_indices = h.p_indices;
|
||||
k.num_rows = h.num_rows;
|
||||
k.num_experts = h.num_experts;
|
||||
k.topk = h.topk;
|
||||
k.stride_input = h.stride_input;
|
||||
k.stride_output = h.stride_output;
|
||||
return k;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::BlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(Kargs kargs) const
|
||||
{
|
||||
index_t block_row_id = static_cast<index_t>(blockIdx.x * Problem::RowsPerBlock);
|
||||
|
||||
if(block_row_id > kargs.num_rows)
|
||||
return;
|
||||
|
||||
index_t block_os_inp = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_input);
|
||||
index_t block_os_out = __builtin_amdgcn_readfirstlane(block_row_id * kargs.stride_output);
|
||||
index_t num_rows_rem = __builtin_amdgcn_readfirstlane(kargs.num_rows - block_row_id);
|
||||
|
||||
const auto input_window = [&]() {
|
||||
const InputType* p_input =
|
||||
reinterpret_cast<const InputType*>(kargs.p_input) + block_os_inp;
|
||||
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_input,
|
||||
make_tuple(num_rows_rem, kargs.num_experts),
|
||||
make_tuple(kargs.stride_input, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
|
||||
auto view = pad_tensor_view(
|
||||
tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
sequence<0, 1>{}); // out-most dim no need pad(leverage oob)
|
||||
|
||||
return make_tile_window(
|
||||
view,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<Problem::Experts>{}),
|
||||
{0, 0});
|
||||
}();
|
||||
|
||||
auto output_window = [&]() {
|
||||
WeightType* p_output = reinterpret_cast<WeightType*>(kargs.p_output) + block_os_out;
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_output,
|
||||
make_tuple(num_rows_rem, kargs.topk),
|
||||
make_tuple(kargs.stride_output, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
auto view =
|
||||
pad_tensor_view(tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
|
||||
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
|
||||
// 2. we loop over topk 1-1, no need padding
|
||||
return make_tile_window(
|
||||
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
|
||||
}();
|
||||
|
||||
auto indices_window = [&]() {
|
||||
IndexType* p_indices = reinterpret_cast<IndexType*>(kargs.p_indices) + block_os_out;
|
||||
auto tmp = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_indices,
|
||||
make_tuple(num_rows_rem, kargs.topk),
|
||||
make_tuple(kargs.stride_output, 1),
|
||||
number<Problem::VectorSize>{},
|
||||
number<1>{});
|
||||
auto view =
|
||||
pad_tensor_view(tmp,
|
||||
make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}),
|
||||
sequence<0, 0>{}); // 1. out-most dim no need pad(leverage oob)
|
||||
// 2. we loop over topk 1-1, no need padding
|
||||
return make_tile_window(
|
||||
view, make_tuple(number<Problem::RowsPerBlock>{}, number<1>{}), {0, 0});
|
||||
}();
|
||||
|
||||
Pipeline{}(input_window,
|
||||
output_window,
|
||||
indices_window,
|
||||
kargs.num_rows,
|
||||
kargs.num_experts,
|
||||
kargs.topk,
|
||||
block_row_id);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
|
||||
#endif
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_ = TopkSoftmaxWarpPerRowPolicy>
|
||||
struct TopkSoftmaxWarpPerRowPipeline
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using Policy = remove_cvref_t<Policy_>;
|
||||
using WeightType = typename Problem::WeightType;
|
||||
|
||||
template <typename InputWindow, typename OutputWindow, typename IndexWindow>
|
||||
CK_TILE_DEVICE auto operator()(const InputWindow& input_window,
|
||||
OutputWindow& out_window,
|
||||
IndexWindow& idx_window,
|
||||
index_t rows,
|
||||
index_t experts,
|
||||
index_t k,
|
||||
index_t block_row_id)
|
||||
{
|
||||
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
auto inp_win = make_tile_window_linear_raw(
|
||||
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
|
||||
#else
|
||||
auto inp_win = make_tile_window_linear(
|
||||
input_window, Policy::template MakeInputDistribution<Problem>(), sequence<0, 1>{});
|
||||
#endif
|
||||
auto out_win = make_tile_window_linear(out_window.get_bottom_tensor_view(),
|
||||
out_window.get_window_lengths(),
|
||||
out_window.get_window_origin(),
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
auto idx_win = make_tile_window_linear(idx_window.get_bottom_tensor_view(),
|
||||
idx_window.get_window_lengths(),
|
||||
idx_window.get_window_origin(),
|
||||
Policy::template MakeOutputDistribution<Problem>());
|
||||
|
||||
auto softmax = Policy::template GetSoftmax<Problem>();
|
||||
auto topk = Policy::template GetTopk<Problem>();
|
||||
|
||||
const index_t grid_rows_per_loop = gridDim.x * Problem::RowsPerBlock;
|
||||
|
||||
while(1)
|
||||
{
|
||||
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
auto x =
|
||||
load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
|
||||
buffer_load_fence(number<0>{});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
#else
|
||||
auto x = load_tile(inp_win);
|
||||
#endif
|
||||
// cast and pad input data
|
||||
auto w = [&]() {
|
||||
#if 0
|
||||
auto w_ = cast_tile<WeightType>(x);
|
||||
|
||||
constexpr auto span_2d = decltype(w_)::get_distributed_spans();
|
||||
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
w_.get_tile_distribution(), i_j_idx);
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
// set to -INF if OOB so that later softmax can work properly
|
||||
w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
|
||||
: w_(i_j_idx);
|
||||
});
|
||||
});
|
||||
return w_;
|
||||
#else
|
||||
auto w_ = make_static_distributed_tensor<WeightType>(x.get_tile_distribution());
|
||||
auto w_f = [&](auto idx) {
|
||||
w_(idx) = type_convert<WeightType>(x(idx));
|
||||
const auto x_indices =
|
||||
get_x_indices_from_distributed_indices(w_.get_tile_distribution(), idx);
|
||||
const auto current_expert = x_indices.at(number<1>{});
|
||||
w_(idx) =
|
||||
current_expert >= experts ? -numeric<WeightType>::infinity() : w_(idx);
|
||||
};
|
||||
tile_sweeper ts{w_, w_f};
|
||||
ts();
|
||||
return w_;
|
||||
#endif
|
||||
}();
|
||||
|
||||
// softmax
|
||||
auto y = softmax(w);
|
||||
|
||||
topk(y, out_win, idx_win, k);
|
||||
|
||||
// check exit
|
||||
if constexpr(Problem::LaunchType == 0)
|
||||
{
|
||||
break;
|
||||
}
|
||||
else
|
||||
{
|
||||
block_row_id += grid_rows_per_loop;
|
||||
if(block_row_id >= rows)
|
||||
break;
|
||||
}
|
||||
|
||||
move_tile_window(inp_win, {grid_rows_per_loop, number<0>{}});
|
||||
move_tile_window(out_win, {grid_rows_per_loop, number<0>{}});
|
||||
move_tile_window(idx_win, {grid_rows_per_loop, number<0>{}});
|
||||
}
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,63 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/softmax.hpp"
|
||||
#include "ck_tile/ops/topk.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct TopkSoftmaxWarpPerRowPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
|
||||
{
|
||||
// TODO: Y dim must have one dim that is not reduced
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<1>,
|
||||
tuple<sequence<Problem::IssuesPerCol,
|
||||
Problem::WarpsPerBlock,
|
||||
Problem::RowsPerWarpPerColIssue>,
|
||||
sequence<Problem::IssuesPerRow, Problem::LanesPerRow, Problem::VectorSize>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 1>>,
|
||||
sequence<1, 2, 2>,
|
||||
sequence<0, 0, 2>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<Problem::LanesPerRow>, // repeat this one
|
||||
tuple<sequence<Problem::IssuesPerCol,
|
||||
Problem::WarpsPerBlock,
|
||||
Problem::RowsPerWarpPerColIssue>,
|
||||
sequence<1>>, // each row write out single element
|
||||
tuple<sequence<1>, sequence<1, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 0>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSoftmax()
|
||||
{
|
||||
using softmax_problem = BlockSoftmax2DProblem<typename Problem::WeightType>;
|
||||
return BlockSoftmax2D<softmax_problem>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetTopk()
|
||||
{
|
||||
using topk_problem = BlockTopkStream2DProblem<typename Problem::WeightType,
|
||||
typename Problem::IndexType,
|
||||
Problem::LanesPerRow>;
|
||||
// Note: replicate is LanesPerRow
|
||||
return BlockTopkStream2D<topk_problem>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,46 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InputType_,
|
||||
typename WeightType_,
|
||||
typename IndexType_,
|
||||
index_t Experts_,
|
||||
index_t IssuesPerCol_ = 2, // issue along col, to make sure block_reduce() OK
|
||||
index_t BytesPerIssue_ = sizeof(InputType_),
|
||||
index_t LaunchType_ = 0, // 0-streaming, >0, persistent #occupancy
|
||||
index_t BlockSize_ = 256>
|
||||
struct TopkSoftmaxWarpPerRowProblem
|
||||
{
|
||||
// TODO: this kernel only support warp per row
|
||||
using InputType = remove_cvref_t<InputType_>;
|
||||
using WeightType = remove_cvref_t<WeightType_>;
|
||||
using IndexType = remove_cvref_t<IndexType_>;
|
||||
|
||||
static constexpr index_t LaunchType = LaunchType_;
|
||||
static constexpr index_t Experts = Experts_;
|
||||
static constexpr index_t BytesPerIssue = BytesPerIssue_;
|
||||
static constexpr index_t IssuesPerCol = IssuesPerCol_;
|
||||
static constexpr index_t BlockSize = BlockSize_;
|
||||
static constexpr index_t WarpSize = get_warp_size();
|
||||
|
||||
static_assert(BytesPerIssue % sizeof(InputType) == 0);
|
||||
static constexpr index_t VectorSize = BytesPerIssue / sizeof(InputType);
|
||||
static_assert(Experts % VectorSize == 0);
|
||||
static constexpr index_t LanesPerRow = min(Experts / VectorSize, WarpSize);
|
||||
static_assert(WarpSize % LanesPerRow == 0);
|
||||
static constexpr index_t RowsPerWarpPerColIssue = WarpSize / LanesPerRow;
|
||||
static constexpr index_t RowsPerWarp = IssuesPerCol * RowsPerWarpPerColIssue;
|
||||
static constexpr index_t IssuesPerRow = Experts / (LanesPerRow * VectorSize);
|
||||
|
||||
static constexpr index_t WarpsPerBlock = BlockSize / WarpSize;
|
||||
static constexpr index_t RowsPerBlock = RowsPerWarp * WarpsPerBlock;
|
||||
};
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user