mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 19:28:33 +00:00
274 lines
8.9 KiB
C++
274 lines
8.9 KiB
C++
// SPDX-License-Identifier: MIT
|
|
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
|
|
|
#pragma once
|
|
|
|
#include "ck/utility/common_header.hpp"
|
|
#include "ck/tensor_description/tensor_descriptor.hpp"
|
|
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
|
#include "ck/tensor_description/tensor_adaptor.hpp"
|
|
|
|
#include "ck/tile_program/tile/tile_distribution.hpp"
|
|
#include "ck/tile_program/tile/tile_window.hpp"
|
|
#include "ck/tile_program/tile/load_tile.hpp"
|
|
#include "ck/tile_program/tile/store_tile.hpp"
|
|
#include "ck/tile_program/tile/tile_elementwise.hpp"
|
|
#include "ck/tile_program/block_tile/block_reduce.hpp"
|
|
|
|
template <typename ADataType,
|
|
typename AccDataType,
|
|
typename BDataType,
|
|
ck::index_t kBlockSize,
|
|
ck::index_t kMPerBlock,
|
|
ck::index_t kNPerBlock>
|
|
struct Softmax
|
|
{
|
|
__device__ static constexpr auto MakeABlockTileDistribution()
|
|
{
|
|
using namespace ck;
|
|
using namespace ck::tile_program;
|
|
|
|
// 4x1 wave
|
|
return make_static_tile_distribution(
|
|
StaticTileDistributionEncoding<Sequence<>,
|
|
Tuple<Sequence<1, 4, 4, 2, 4>, Sequence<4, 1, 32>>,
|
|
Tuple<Sequence<1, 2>, Sequence<1, 2>>,
|
|
Tuple<Sequence<1, 1>, Sequence<3, 2>>,
|
|
Sequence<1, 2, 1, 1>,
|
|
Sequence<0, 0, 2, 4>>{});
|
|
}
|
|
|
|
__device__ void
|
|
MultiPassSoftmax(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
|
|
{
|
|
using namespace ck;
|
|
using namespace ck::tile_program;
|
|
using namespace ck::tile_program::block;
|
|
|
|
constexpr auto I0 = Number<0>{};
|
|
constexpr auto I1 = Number<1>{};
|
|
|
|
// A DRAM tensor view
|
|
const auto a_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
|
|
p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});
|
|
|
|
const auto iM = get_block_id() * kMPerBlock;
|
|
|
|
// A DRAM window
|
|
auto a_dram_window =
|
|
make_tile_window(a_dram,
|
|
make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
|
|
{iM, 0},
|
|
MakeABlockTileDistribution());
|
|
|
|
// m = rowmax(A)
|
|
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
|
|
|
auto m = decltype(block_tile_reduce<AccDataType>(
|
|
load_tile(a_dram_window), Sequence<1>{}, f_max, ADataType{})){};
|
|
|
|
tile_elementwise_inout(
|
|
[&](auto& e) { e = type_convert<AccDataType>(NumericLimits<ADataType>::Lowest()); }, m);
|
|
|
|
index_t iN = 0;
|
|
|
|
do
|
|
{
|
|
// load A tile from DRAM
|
|
const auto a = load_tile(a_dram_window);
|
|
|
|
// m = rowmax(A)
|
|
block_tile_reduce(m, a, Sequence<1>{}, f_max);
|
|
|
|
move_tile_window(a_dram_window, {0, kNPerBlock});
|
|
|
|
iN += kNPerBlock;
|
|
|
|
} while(iN < N);
|
|
|
|
// cross lane reduce: max
|
|
block_tile_reduce_sync(m, f_max);
|
|
|
|
// reset window location
|
|
iN = 0;
|
|
move_tile_window(a_dram_window, {0, -N});
|
|
|
|
// l = rowsum(exp(A - m))
|
|
auto l = make_static_distributed_tensor<AccDataType>(m.GetTileDistribution());
|
|
|
|
tile_elementwise_inout([&](auto& e) { e = 0; }, l);
|
|
|
|
do
|
|
{
|
|
// load A tile from DRAM
|
|
const auto a = load_tile(a_dram_window);
|
|
|
|
constexpr auto a_spans = decltype(a)::GetDistributedSpans();
|
|
|
|
sweep_tile_span(a_spans[I0], [&](auto idx0) {
|
|
constexpr auto i_idx = make_tuple(idx0);
|
|
|
|
sweep_tile_span(a_spans[I1], [&](auto idx1) {
|
|
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
|
|
|
// l = rowsum(exp(A - m))
|
|
l(i_idx) += math::exp(a[i_j_idx] - m[i_idx]);
|
|
});
|
|
});
|
|
|
|
move_tile_window(a_dram_window, {0, kNPerBlock});
|
|
|
|
iN += kNPerBlock;
|
|
|
|
} while(iN < N);
|
|
|
|
// cross lane reduce: sum
|
|
block_tile_reduce_sync(l, [](auto e0, auto e1) { return e0 + e1; });
|
|
|
|
// reset window location
|
|
iN = 0;
|
|
move_tile_window(a_dram_window, {0, -N});
|
|
|
|
// B DRAM tensor view
|
|
const auto b_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
|
|
p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});
|
|
|
|
// B DRAM window
|
|
auto b_dram_window = make_tile_window(
|
|
b_dram, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});
|
|
|
|
// B = exp(A - m) / l
|
|
do
|
|
{
|
|
// load A tile from DRAM
|
|
const auto a = load_tile(a_dram_window);
|
|
|
|
constexpr auto a_spans = decltype(a)::GetDistributedSpans();
|
|
|
|
auto b = make_static_distributed_tensor<BDataType>(a.GetTileDistribution());
|
|
|
|
sweep_tile_span(a_spans[I0], [&](auto idx0) {
|
|
constexpr auto i_idx = make_tuple(idx0);
|
|
|
|
sweep_tile_span(a_spans[I1], [&](auto idx1) {
|
|
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
|
|
|
// B = exp(A - m) / l
|
|
b(i_j_idx) =
|
|
type_convert<BDataType>(math::exp(a[i_j_idx] - m[i_idx]) / l[i_idx]);
|
|
});
|
|
});
|
|
|
|
// store B tile
|
|
store_tile(b_dram_window, b);
|
|
|
|
move_tile_window(a_dram_window, {0, kNPerBlock});
|
|
move_tile_window(b_dram_window, {0, kNPerBlock});
|
|
|
|
iN += kNPerBlock;
|
|
|
|
} while(iN < N);
|
|
}
|
|
|
|
__device__ void
|
|
SinglePassSoftmax(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
|
|
{
|
|
using namespace ck;
|
|
using namespace ck::tile_program;
|
|
using namespace ck::tile_program::block;
|
|
|
|
constexpr auto I0 = Number<0>{};
|
|
constexpr auto I1 = Number<1>{};
|
|
|
|
// A DRAM tensor view
|
|
const auto a_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
|
|
p_a, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});
|
|
|
|
const auto iM = get_block_id() * kMPerBlock;
|
|
|
|
// A DRAM window
|
|
auto a_dram_window =
|
|
make_tile_window(a_dram,
|
|
make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}),
|
|
{iM, 0},
|
|
MakeABlockTileDistribution());
|
|
|
|
// f_max
|
|
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
|
|
|
|
// m = rowmax(A)
|
|
auto m = decltype(block_tile_reduce<AccDataType>(
|
|
load_tile(a_dram_window), Sequence<1>{}, f_max, ADataType{})){};
|
|
|
|
tile_elementwise_inout(
|
|
[&](auto& e) { e = type_convert<AccDataType>(NumericLimits<ADataType>::Lowest()); }, m);
|
|
|
|
// l = rowsum(exp(A - m))
|
|
auto l = make_static_distributed_tensor<AccDataType>(m.GetTileDistribution());
|
|
|
|
tile_elementwise_inout([&](auto& e) { e = 0; }, l);
|
|
|
|
// load A tile from DRAM
|
|
const auto a = load_tile(a_dram_window);
|
|
|
|
constexpr auto a_spans = decltype(a)::GetDistributedSpans();
|
|
|
|
// m = rowmax(A)
|
|
block_tile_reduce(m, a, Sequence<1>{}, f_max);
|
|
|
|
// cross lane reduce: max
|
|
block_tile_reduce_sync(m, f_max);
|
|
|
|
// l = rowsum(exp(A - m))
|
|
sweep_tile_span(a_spans[I0], [&](auto idx0) {
|
|
constexpr auto i_idx = make_tuple(idx0);
|
|
|
|
sweep_tile_span(a_spans[I1], [&](auto idx1) {
|
|
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
|
|
|
l(i_idx) += math::exp(a[i_j_idx] - m[i_idx]);
|
|
});
|
|
});
|
|
|
|
// cross lane reduce: sum
|
|
block_tile_reduce_sync(l, [](auto e0, auto e1) { return e0 + e1; });
|
|
|
|
auto b = make_static_distributed_tensor<BDataType>(a.GetTileDistribution());
|
|
|
|
// B = exp(A - m) / l
|
|
sweep_tile_span(a_spans[I0], [&](auto idx0) {
|
|
constexpr auto i_idx = make_tuple(idx0);
|
|
|
|
sweep_tile_span(a_spans[I1], [&](auto idx1) {
|
|
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
|
|
|
b(i_j_idx) = type_convert<BDataType>(math::exp(a[i_j_idx] - m[i_idx]) / l[i_idx]);
|
|
});
|
|
});
|
|
|
|
// B DRAM tensor view
|
|
const auto b_dram = make_naive_tensor_view<AddressSpaceEnum::Global>(
|
|
p_b, make_tuple(M, N), make_tuple(N, 1), Number<32>{}, Number<1>{});
|
|
|
|
// B DRAM window
|
|
auto b_dram_window = make_tile_window(
|
|
b_dram, make_tuple(Number<kMPerBlock>{}, Number<kNPerBlock>{}), {iM, 0});
|
|
|
|
// store B tile
|
|
store_tile(b_dram_window, b);
|
|
}
|
|
|
|
__device__ void
|
|
operator()(const ADataType* p_a, BDataType* p_b, ck::index_t M, ck::index_t N) const
|
|
{
|
|
if(N > kNPerBlock)
|
|
{
|
|
MultiPassSoftmax(p_a, p_b, M, N);
|
|
}
|
|
else
|
|
{
|
|
SinglePassSoftmax(p_a, p_b, M, N);
|
|
}
|
|
}
|
|
};
|