topk_softmax (#1592)

* topk_softmax

* remove some file

* fix atomix linear_offset

* address various comment, and change sfc get_index api to static(tuple)
This commit is contained in:
carlushuang
2024-10-26 23:52:49 +08:00
committed by GitHub
parent 31bf253aeb
commit b098b71b05
41 changed files with 5603 additions and 226 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
@@ -9,43 +9,81 @@
namespace ck_tile {
template <typename ADataType, typename AccDataType, typename BDataType>
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n,
HostTensor<BDataType>& b_m_n)
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void
reference_softmax(const HostTensor<InputType>& x, HostTensor<OutputType>& y, index_t dim = -1)
{
auto f = [&](auto m) {
const int N = a_m_n.mDesc.get_lengths()[1];
index_t rank = x.get_num_of_dimension();
assert(rank == y.get_num_of_dimension());
assert(dim == -1 || dim < rank);
AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
index_t target_dim = dim == -1 ? (rank - 1) : dim;
index_t softmax_len = x.get_length(target_dim);
index_t n_parallel = x.get_element_size() / softmax_len;
auto x_len = x.get_lengths();
// max
for(int n = 0; n < N; ++n)
auto f = [&](auto i_element) {
std::vector<size_t> coord = [&]() {
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--)
{
if(i == target_dim)
continue;
t_[i] = r % x_len[i];
r = r / x_len[i];
}
return t_;
}();
ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
// compute max
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
v_max = v_max < v_a ? v_a : v_max;
auto c_ = coord;
c_[target_dim] = idx;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_max = v_max < v_x ? v_x : v_max;
}
AccDataType v_exp_sum = 0;
ComputeType v_exp_sum = static_cast<ComputeType>(0);
// sum
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
v_exp_sum += ck_tile::exp(v_a - v_max);
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
v_exp_sum += ck_tile::exp(v_x - v_max);
}
// elementwise
for(int n = 0; n < N; ++n)
for(auto idx = 0; idx < softmax_len; idx++)
{
const ADataType v_a = a_m_n(m, n);
auto c_ = coord;
c_[target_dim] = idx;
b_m_n(m, n) = ck_tile::exp(v_a - v_max) / v_exp_sum;
const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
auto out = ck_tile::exp(v_x - v_max) / v_exp_sum;
y(c_) = ck_tile::type_convert<OutputType>(out);
}
};
make_ParallelTensorFunctor(f,
b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
return y;
}
} // namespace ck_tile

View File

@@ -0,0 +1,124 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>
namespace ck_tile {
/*
similiar to torch.topk()
x (Tensor) the input tensor.
k (int) the k in “top-k”
dim (int, optional) the dimension to sort along
largest (bool, optional) largest or smallest elements
sorted (bool, optional) elements in sorted order or not
output:
y_values
y_indices
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
*/
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
HostTensor<DataType>& y_values,
HostTensor<IndexType>& y_indices,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
// rank must be the same
index_t rank = x.get_num_of_dimension();
assert(rank == y_values.get_num_of_dimension());
assert(rank == y_indices.get_num_of_dimension());
assert(dim == -1 || dim < rank);
index_t topk_dim = dim == -1 ? (rank - 1) : dim;
index_t topk_src_len = x.get_length(topk_dim);
auto x_len = x.get_lengths();
assert(k <= topk_src_len);
assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));
index_t n_parallel = x.get_element_size() / topk_src_len;
// clang-format off
auto f = [&](auto i_element) {
std::vector<size_t> topk_coord = [&](){
std::vector<size_t> t_(rank, 0);
size_t r = i_element;
for(index_t i = rank - 1; i >= 0; i--) {
if(i == topk_dim) continue; // topk dim should be zero
t_[i] = r % x_len[i]; r = r / x_len[i];
}
return t_;
}();
using elem_t = std::pair<DataType, IndexType>;
std::vector<elem_t> q = [&](){
std::vector<elem_t> t_(topk_src_len);
for(index_t i = 0; i < topk_src_len; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
t_[i].first = x(c_); t_[i].second = i;
}
return t_;
}();
// run topk
if(largest) {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
}
} else {
std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
if(sorted) {
std::sort(q.begin(), q.begin() + k - 1,
[](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
}
}
// write out
for(index_t i = 0; i < k; i++) {
auto c_ = topk_coord; c_[topk_dim] = i;
y_values(c_) = q[i].first; y_indices(c_) = q[i].second;
}
};
// clang-format on
make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}
// TODO: if using this method, the return tensor would be dense(no stride)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
index_t k,
index_t dim = -1,
bool largest = true,
bool sorted = true)
{
auto lens = x.get_lengths();
index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
assert(target_dim < lens.size());
assert(k <= lens[target_dim]);
lens[target_dim] = k;
HostTensor<DataType> y_values(lens);
HostTensor<IndexType> y_indices(lens);
reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
return ck_tile::make_tuple(y_values, y_indices);
}
} // namespace ck_tile