mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-20 12:59:49 +00:00
Merge commit '3052d7c9e6972d5ea7d2225ab78e45554ba70efd' into develop
This commit is contained in:
@@ -78,6 +78,7 @@
|
||||
#include "ck_tile/core/utility/print.hpp"
|
||||
#include "ck_tile/core/utility/random.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
||||
#include "ck_tile/core/utility/static_counter.hpp"
|
||||
#include "ck_tile/core/utility/to_sequence.hpp"
|
||||
#include "ck_tile/core/utility/transpose_vectors.hpp"
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -18,16 +19,14 @@ struct Add
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + x;
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -46,16 +45,14 @@ struct SquareAdd
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, float, double, int32_t, int8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return y + (x * x);
|
||||
}
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
typename = std::enable_if_t<is_any_of<T, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T& y, T x) const
|
||||
{
|
||||
float y_ = type_convert<float>(y);
|
||||
@@ -66,48 +63,74 @@ struct SquareAdd
|
||||
|
||||
struct Max
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::lowest();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, x);
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, x);
|
||||
if(x > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
struct AbsMax
|
||||
{
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE static constexpr T GetIdentityValue()
|
||||
{
|
||||
return numeric<T>::zero();
|
||||
};
|
||||
|
||||
template <typename T,
|
||||
typename = std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double> ||
|
||||
std::is_same_v<T, int32_t> || std::is_same_v<T, int8_t> ||
|
||||
std::is_same_v<T, half_t> || std::is_same_v<T, bf16_t> ||
|
||||
std::is_same_v<T, fp8_t> || std::is_same_v<T, bf8_t>>>
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x) const
|
||||
{
|
||||
return max(y, abs(x));
|
||||
}
|
||||
|
||||
// Overload with changed flag for index tracking
|
||||
template <
|
||||
typename T,
|
||||
typename = std::enable_if_t<
|
||||
is_any_of<T, float, double, int32_t, int8_t, half_t, bf16_t, fp8_t, bf8_t>::value>>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(const T& y, const T x, bool& changed) const
|
||||
{
|
||||
T new_max = max(y, abs(x));
|
||||
if(abs(x) > y)
|
||||
{
|
||||
changed = true;
|
||||
}
|
||||
return new_max;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ReduceOp
|
||||
|
||||
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
50
include/ck_tile/core/utility/reduce_operator_accumulate.hpp
Normal file
@@ -0,0 +1,50 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/config.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/// @brief Accumulate with index tracking reductions, provides deterministic first occurring index
|
||||
struct AccumulateWithIndex
|
||||
{
|
||||
template <typename ReduceOp, typename T, typename IndexType>
|
||||
CK_TILE_HOST_DEVICE void operator()(const ReduceOp& reduce_func,
|
||||
T& current_value,
|
||||
IndexType& current_index,
|
||||
const T& new_value,
|
||||
const IndexType& new_index) const
|
||||
{
|
||||
bool changed = false;
|
||||
current_value = reduce_func(current_value, new_value, changed);
|
||||
|
||||
if(changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
else if(new_index < current_index)
|
||||
{
|
||||
bool reverse_changed = false;
|
||||
reduce_func(new_value, current_value, reverse_changed);
|
||||
|
||||
if(!reverse_changed)
|
||||
{
|
||||
current_index = new_index;
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct Accumulate
|
||||
{
|
||||
template <typename ReduceOp, typename T>
|
||||
CK_TILE_HOST_DEVICE void
|
||||
operator()(const ReduceOp& reduce_func, T& current_value, const T& new_value) const
|
||||
{
|
||||
current_value = reduce_func(current_value, new_value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -7,17 +7,21 @@
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck_tile/ops/pooling/kernel/pool_kernel.hpp"
|
||||
#include <thread>
|
||||
#include <cmath>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -45,6 +49,8 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t y = 0; y < Y; ++y)
|
||||
{
|
||||
// Calculate input height index with stride, dilation, and padding
|
||||
@@ -58,13 +64,32 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
if(hi >= 0 && hi < H && wi >= 0 && wi < W)
|
||||
{
|
||||
const ComputeDataType v_in = type_convert<ComputeDataType>(input(n, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index = input.GetOffsetFromMultiIndex(n, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
}
|
||||
|
||||
output(n, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
output_index(n, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
@@ -74,11 +99,14 @@ CK_TILE_HOST void reference_pool2d(const HostTensor<InDataType>& input,
|
||||
template <typename InDataType,
|
||||
typename ComputeDataType,
|
||||
typename OutDataType,
|
||||
typename IndexDataType,
|
||||
typename ReduceOp,
|
||||
typename TensorShape,
|
||||
typename WindowShape>
|
||||
typename WindowShape,
|
||||
bool OutputIndex = false>
|
||||
CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
HostTensor<OutDataType>& output,
|
||||
HostTensor<IndexDataType>& output_index,
|
||||
PoolKernelArgs<TensorShape, WindowShape> kargs,
|
||||
ReduceOp reduce_op)
|
||||
{
|
||||
@@ -112,6 +140,8 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
auto f = [&](auto n, auto do_, auto ho, auto wo, auto c) {
|
||||
ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
|
||||
|
||||
IndexDataType current_index = 0; // Declare outside if constexpr for efficiency
|
||||
|
||||
for(ck_tile::index_t z = 0; z < Z; ++z)
|
||||
{
|
||||
// Calculate input depth index with stride, dilation, and padding
|
||||
@@ -131,7 +161,22 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
{
|
||||
const ComputeDataType v_in =
|
||||
type_convert<ComputeDataType>(input(n, di, hi, wi, c));
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
IndexDataType flat_index =
|
||||
input.GetOffsetFromMultiIndex(n, di, hi, wi, c);
|
||||
bool changed = false;
|
||||
v_acc = reduce_op(v_acc, v_in, changed);
|
||||
if(changed)
|
||||
{
|
||||
current_index = flat_index;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_acc = reduce_op(v_acc, v_in);
|
||||
}
|
||||
}
|
||||
// For positions outside bounds, we implicitly use identity value
|
||||
}
|
||||
@@ -139,10 +184,15 @@ CK_TILE_HOST void reference_pool3d(const HostTensor<InDataType>& input,
|
||||
}
|
||||
|
||||
output(n, do_, ho, wo, c) = ck_tile::type_convert<OutDataType>(v_acc);
|
||||
|
||||
if constexpr(OutputIndex)
|
||||
{
|
||||
|
||||
output_index(n, do_, ho, wo, c) = current_index;
|
||||
}
|
||||
};
|
||||
|
||||
// Parallelize over all output dimensions
|
||||
make_ParallelTensorFunctor(f, N, Do, Ho, Wo, C)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -17,6 +17,7 @@ struct PoolHostArgs
|
||||
|
||||
CK_TILE_HOST PoolHostArgs(const void* input_ptr_,
|
||||
void* output_ptr_,
|
||||
void* output_index_ptr_,
|
||||
TensorShape input_shape_,
|
||||
TensorShape output_shape_,
|
||||
TensorShape input_strides_,
|
||||
@@ -28,6 +29,7 @@ struct PoolHostArgs
|
||||
WindowShape input_right_pads_)
|
||||
: input_ptr(input_ptr_),
|
||||
output_ptr(output_ptr_),
|
||||
output_index_ptr(output_index_ptr_),
|
||||
input_shape(input_shape_),
|
||||
output_shape(output_shape_),
|
||||
input_strides(input_strides_),
|
||||
@@ -42,6 +44,7 @@ struct PoolHostArgs
|
||||
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
@@ -60,6 +63,7 @@ struct PoolKernelArgs
|
||||
{
|
||||
const void* input_ptr;
|
||||
void* output_ptr;
|
||||
void* output_index_ptr;
|
||||
TensorShape input_shape;
|
||||
TensorShape output_shape;
|
||||
TensorShape input_strides;
|
||||
@@ -80,6 +84,7 @@ struct PoolKernel
|
||||
using InDataType = ck_tile::remove_cvref_t<typename Problem::InDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using OutDataType = ck_tile::remove_cvref_t<typename Problem::OutDataType>;
|
||||
using IndexDataType = ck_tile::remove_cvref_t<typename Problem::IndexDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
|
||||
|
||||
@@ -205,7 +210,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename TensorShape, typename WindowShape>
|
||||
@@ -338,7 +359,23 @@ struct PoolKernel
|
||||
tensor_view<decltype(out_buffer_view), decltype(out_desc_padded)>{out_buffer_view,
|
||||
out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded);
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
auto out_index_buffer_view = make_buffer_view<address_space_enum::global>(
|
||||
static_cast<IndexDataType*>(kargs.output_index_ptr),
|
||||
out_desc.get_element_space_size(),
|
||||
IndexDataType(-1));
|
||||
const auto out_index_tensor_padded =
|
||||
tensor_view<decltype(out_index_buffer_view), decltype(out_desc_padded)>{
|
||||
out_index_buffer_view, out_desc_padded};
|
||||
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, out_index_tensor_padded);
|
||||
}
|
||||
else
|
||||
{
|
||||
// Return a dummy tensor for the third element when index output is not needed
|
||||
return make_tuple(in_tensor_padded, out_tensor_padded, null_tensor{});
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
@@ -354,7 +391,7 @@ struct PoolKernel
|
||||
const auto iM = get_block_id() * S::Block_M;
|
||||
|
||||
// Get tensors based on dimensionality
|
||||
auto [in_tensor_padded, out_tensor_padded] = [&]() {
|
||||
auto [in_tensor_padded, out_tensor_padded, out_index_tensor_padded] = [&]() {
|
||||
if constexpr(WindowShape::size() == 2)
|
||||
return MakeTensorView2D(kargs);
|
||||
else if constexpr(WindowShape::size() == 3)
|
||||
@@ -387,16 +424,57 @@ struct PoolKernel
|
||||
auto y_tile = block_reduce2d.template MakeYBlockTile<XTensorTile>();
|
||||
set_tile(y_tile, reduce_op.template GetIdentityValue<ComputeDataType>());
|
||||
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
if constexpr(Problem::kOutputIndex)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
auto y_index_window =
|
||||
make_tile_window(out_index_tensor_padded, make_tuple(number<S::Block_M>{}), {iM});
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
auto y_index_tile =
|
||||
block_reduce2d.template MakeYIndexBlockTile<XTensorTile, IndexDataType>();
|
||||
set_tile(y_index_tile, IndexDataType(0));
|
||||
|
||||
// Main reduction loop - with index tracking
|
||||
for(int k_tile = amd_wave_read_first_lane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
auto index_calculator = [&](const auto& x_indices) {
|
||||
// Get global coordinates in the 2D matrix space (M, N)
|
||||
const auto global_M = x_indices.at(number<0>{}) + iM;
|
||||
const auto global_N = (k_tile * S::Block_N) + x_indices.at(number<1>{});
|
||||
return in_tensor_padded.get_tensor_descriptor().calculate_offset(
|
||||
make_tuple(global_M, global_N));
|
||||
};
|
||||
|
||||
block_reduce2d(x_tile, y_tile, y_index_tile, reduce_op, index_calculator);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, y_index_tile, reduce_op);
|
||||
if constexpr(Problem::kNeedCrossWarpSync)
|
||||
{
|
||||
__shared__ char smem_indices[Policy::template GetIndicesSmemSize<Problem>()];
|
||||
|
||||
block_reduce2d_cross_warp(y_tile, y_index_tile, smem, smem_indices, reduce_op);
|
||||
}
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
store_tile(y_index_window, cast_tile<IndexDataType>(y_index_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
// Main reduction loop - without index tracking
|
||||
for(int k_tile = __builtin_amdgcn_readfirstlane(0); k_tile < num_k_tiles; ++k_tile)
|
||||
{
|
||||
const auto x_tile = load_tile(x_window);
|
||||
block_reduce2d(x_tile, y_tile, reduce_op);
|
||||
move_tile_window(x_window, {0, S::Block_N});
|
||||
}
|
||||
|
||||
block_reduce2d_sync(y_tile, reduce_op);
|
||||
block_reduce2d_cross_warp(y_tile, smem, reduce_op);
|
||||
|
||||
store_tile(y_window, cast_tile<OutDataType>(y_tile));
|
||||
}
|
||||
}
|
||||
|
||||
/// @brief Validates if the given arguments are supported by the pooling kernel.
|
||||
@@ -481,6 +559,7 @@ struct PoolKernel
|
||||
{
|
||||
return PoolKernelArgs<TensorShape, WindowShape>{host_args.input_ptr,
|
||||
host_args.output_ptr,
|
||||
host_args.output_index_ptr,
|
||||
host_args.input_shape,
|
||||
host_args.output_shape,
|
||||
host_args.input_strides,
|
||||
|
||||
@@ -32,7 +32,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
@@ -41,7 +42,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -50,7 +52,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -61,7 +64,8 @@ struct PoolDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
@@ -76,5 +80,25 @@ struct PoolDefaultPolicy
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
|
||||
using P_ = BlockReduce2dProblem<typename Problem::InDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::InDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_index_block_tile = decltype(block_reduce2d::template MakeYIndexBlockTile<
|
||||
x_block_tile,
|
||||
typename Problem::IndexDataType>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>()
|
||||
.template GetIndicesSmemSize<y_index_block_tile>();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -26,6 +26,8 @@ struct PoolProblem
|
||||
using OutputIndex = bool_constant<OutputIndex_>;
|
||||
using PropagateNan = bool_constant<PropagateNan_>;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
static constexpr bool kPropagateNan = PropagateNan_;
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/utility/reduce_operator_accumulate.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
@@ -50,6 +51,53 @@ struct BlockReduce2d
|
||||
|
||||
CK_TILE_DEVICE constexpr BlockReduce2d() {}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename IndexCalculatorFunc,
|
||||
typename ReducePacksPerXDim>
|
||||
CK_TILE_DEVICE void reduce_impl(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
const IndexCalculatorFunc& index_calculator,
|
||||
ReducePacksPerXDim)
|
||||
{
|
||||
sweep_tile<XDistributedTensor_>(
|
||||
[&](auto... idx_) {
|
||||
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
|
||||
|
||||
(..., [&](auto idx) {
|
||||
auto val = ck_tile::type_convert<ComputeDataType>(x_tensor[idx]);
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
|
||||
const auto x_indices = get_x_indices_from_distributed_indices(
|
||||
XDistributedTensor_::get_tile_distribution(), idx);
|
||||
const auto new_idx = index_calculator(x_indices);
|
||||
auto current_idx = y_index_tensor(idx_0);
|
||||
|
||||
AccumulateWithIndex{}(
|
||||
reduce_func, y_tensor(idx_0), current_idx, val, new_idx);
|
||||
|
||||
y_index_tensor(idx_0) =
|
||||
type_convert<typename YIndexDistributedTensor_::DataType>(current_idx);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, y_tensor(idx_0), val);
|
||||
}
|
||||
}(idx_));
|
||||
},
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
public:
|
||||
// Overload for non-index tracking
|
||||
template <
|
||||
typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
@@ -61,13 +109,36 @@ struct BlockReduce2d
|
||||
const ReduceFunc& reduce_func,
|
||||
ReducePacksPerXDim = {})
|
||||
{
|
||||
sweep_tile<XDistributedTensor_>(
|
||||
[&](auto... idx_) {
|
||||
constexpr auto idx_0 = make_tuple(make_tuple(idx_[number<0>{}]...)[number<0>{}]);
|
||||
y_tensor(idx_0) = reduce_func(
|
||||
y_tensor(idx_0), ck_tile::type_convert<ComputeDataType>(x_tensor[idx_])...);
|
||||
},
|
||||
reduce_impl<false>(
|
||||
x_tensor,
|
||||
y_tensor,
|
||||
y_tensor, // dummy
|
||||
reduce_func,
|
||||
[](auto) { return 0; }, // dummy
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
// Overload for index tracking
|
||||
template <typename XDistributedTensor_,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc,
|
||||
typename IndexCalculatorFunc,
|
||||
typename ReducePacksPerXDim = uniform_sequence_gen_t<2, 1>>
|
||||
CK_TILE_DEVICE void operator()(const XDistributedTensor_& x_tensor,
|
||||
YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func,
|
||||
const IndexCalculatorFunc& index_calculator,
|
||||
ReducePacksPerXDim = {})
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(x_tensor,
|
||||
y_tensor,
|
||||
y_index_tensor,
|
||||
reduce_func,
|
||||
index_calculator,
|
||||
ReducePacksPerXDim{});
|
||||
}
|
||||
|
||||
#if 0
|
||||
constexpr auto I0 = number<0>{};
|
||||
@@ -90,7 +161,6 @@ struct BlockReduce2d
|
||||
y_tensor(y_dstr_idx) = y;
|
||||
});
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_>
|
||||
CK_TILE_DEVICE static auto MakeYBlockTile()
|
||||
@@ -111,6 +181,25 @@ struct BlockReduce2d
|
||||
return tensor;
|
||||
}
|
||||
|
||||
template <typename XDistributedTensor_, typename IndexDataType = index_t>
|
||||
CK_TILE_DEVICE static auto MakeYIndexBlockTile()
|
||||
{
|
||||
static_assert(std::is_same_v<XDataType, typename XDistributedTensor_::DataType>, "wrong!");
|
||||
|
||||
// FIXME: hard coded to reduce 2nd axis
|
||||
constexpr auto reduce_dims = sequence<1>{};
|
||||
|
||||
constexpr auto dstr =
|
||||
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
|
||||
XDistributedTensor_::get_tile_distribution()
|
||||
.get_static_tile_distribution_encoding(),
|
||||
reduce_dims));
|
||||
|
||||
auto tensor = make_static_distributed_tensor<IndexDataType>(dstr);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
||||
// uniform_sequence_gen_t<NSize, Value> generates sequence of NSize elements filled with Value
|
||||
// e.g., uniform_sequence_gen_t<2, 1> → {1, 1} and uniform_sequence_gen_t<3, 4> → {4, 4, 4}
|
||||
template <typename XDistributedTensor_,
|
||||
@@ -135,8 +224,14 @@ struct BlockReduce2dSync
|
||||
{
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using Dstr = typename YDistributedTensor_::StaticTileDistribution;
|
||||
using DstrEncode = typename Dstr::DstrEncode;
|
||||
@@ -157,6 +252,14 @@ struct BlockReduce2dSync
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
auto v_local = y_tensor.get_thread_buffer()[i];
|
||||
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
IndexDataType idx_local{};
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_local = y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
|
||||
// cross-lane reduce for replication
|
||||
// only reduce on R dimension correspond to lane
|
||||
// (lane id maps to this R dimension)
|
||||
@@ -183,15 +286,46 @@ struct BlockReduce2dSync
|
||||
|
||||
// pull data from remote lane
|
||||
const auto v_remote = warp_shuffle(v_local, src_lane);
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
const auto idx_remote = warp_shuffle(idx_local, src_lane);
|
||||
|
||||
AccumulateWithIndex{}(
|
||||
reduce_func, v_local, idx_local, v_remote, idx_remote);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, v_local, v_remote);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
// TODO - Do we need to broadcast to other lane?
|
||||
y_tensor.get_thread_buffer()(i) = v_local;
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i) = idx_local;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(y_tensor, y_index_tensor, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
// BlockReduce2dCrossWarpSync: Cross-warp reduction (Stage 3)
|
||||
@@ -250,15 +384,39 @@ struct BlockReduce2dCrossWarpSync
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
// return in byte - separate shared memory size calculation for indices
|
||||
template <typename YIndexDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(IndexDataType);
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices_ptr,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
IndexDataType* smem_indices = nullptr;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
|
||||
}
|
||||
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
|
||||
@@ -275,6 +433,11 @@ struct BlockReduce2dCrossWarpSync
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
// Store the i-th element of this warp's thread_buffer into SMEM
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices[smem_offset + i * num_warps] =
|
||||
y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -282,10 +445,19 @@ struct BlockReduce2dCrossWarpSync
|
||||
// We let each warp holds a duplication to do reduction.
|
||||
const index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
const index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
DataType v[num_reduce_warps];
|
||||
static_for<0, num_reduce_warps, 1>{}(
|
||||
[&](auto idx) { v[idx] = smem_ptr[i * num_warps + local_smem_os + idx]; });
|
||||
[[maybe_unused]] std::
|
||||
conditional_t<kProcessIndex, IndexDataType[num_reduce_warps], IndexDataType> idx_v;
|
||||
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto idx) {
|
||||
v[idx] = smem_ptr[i * num_warps + local_smem_os + idx];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_v[idx] = smem_indices[i * num_warps + local_smem_os + idx];
|
||||
}
|
||||
});
|
||||
|
||||
static_assert(is_power_of_two_integer(num_reduce_warps),
|
||||
"wrong! only support power of 2 reduction");
|
||||
@@ -299,14 +471,44 @@ struct BlockReduce2dCrossWarpSync
|
||||
constexpr index_t i1 = idx_ + stride;
|
||||
if constexpr(i1 < num_reduce_warps)
|
||||
{
|
||||
v[i0] = reduce_func(v[i0], v[i1]);
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
AccumulateWithIndex{}(reduce_func, v[i0], idx_v[i0], v[i1], idx_v[i1]);
|
||||
}
|
||||
else
|
||||
{
|
||||
Accumulate{}(reduce_func, v[i0], v[i1]);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i) = v[0];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i) = idx_v[0];
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(
|
||||
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -364,15 +566,39 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
return num_warps * thread_buf_size * sizeof(DataType);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
// return in byte - separate shared memory size calculation for indices
|
||||
template <typename YIndexDistributedTensor_>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
constexpr index_t thread_buf_size = YIndexDistributedTensor_::get_thread_buffer_size();
|
||||
constexpr index_t num_warps = BlockShape::BlockSize / get_warp_size();
|
||||
return num_warps * thread_buf_size * sizeof(IndexDataType);
|
||||
}
|
||||
|
||||
private:
|
||||
template <bool kProcessIndex,
|
||||
typename YDistributedTensor_,
|
||||
typename YIndexDistributedTensor_,
|
||||
typename ReduceFunc>
|
||||
CK_TILE_DEVICE void reduce_impl(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices_ptr,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
using DataType = typename YDistributedTensor_::DataType;
|
||||
using IndexDataType = typename YIndexDistributedTensor_::DataType;
|
||||
|
||||
constexpr index_t thread_buf_size = YDistributedTensor_::get_thread_buffer_size();
|
||||
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
DataType* smem_ptr = reinterpret_cast<DataType*>(smem);
|
||||
IndexDataType* smem_indices = nullptr;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices = reinterpret_cast<IndexDataType*>(smem_indices_ptr);
|
||||
}
|
||||
|
||||
const index_t lane_id = get_lane_id();
|
||||
const index_t warp_id = get_warp_id();
|
||||
constexpr auto num_reduce_warps = GetReduceWarps<YDistributedTensor_>();
|
||||
@@ -388,6 +614,11 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
{
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i) {
|
||||
smem_ptr[smem_offset + i * num_warps] = y_tensor.get_thread_buffer()[i];
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
smem_indices[smem_offset + i * num_warps] =
|
||||
y_index_tensor.get_thread_buffer()[i];
|
||||
}
|
||||
});
|
||||
}
|
||||
block_sync_lds();
|
||||
@@ -395,31 +626,86 @@ struct BlockReduce2dLinearCrossWarpSync
|
||||
// load from smem. here we let everythread to do compute :)
|
||||
index_t local_warp_id = warp_id / num_reduce_warps;
|
||||
index_t local_smem_os = local_warp_id * num_reduce_warps;
|
||||
|
||||
DataType all_scratch[thread_buf_size * num_reduce_warps];
|
||||
[[maybe_unused]] std::conditional_t<kProcessIndex,
|
||||
IndexDataType[thread_buf_size * num_reduce_warps],
|
||||
IndexDataType> all_indices;
|
||||
|
||||
// Load data from shared memory
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
static_for<0, num_reduce_warps, 1>{}([&](auto i_1) {
|
||||
all_scratch[i_0 * num_reduce_warps + i_1] =
|
||||
smem_ptr[i_0 * num_warps + local_smem_os + i_1];
|
||||
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
all_indices[i_0 * num_reduce_warps + i_1] =
|
||||
smem_indices[i_0 * num_warps + local_smem_os + i_1];
|
||||
}
|
||||
});
|
||||
});
|
||||
block_sync_lds(); // TODO: we don't need sync here
|
||||
|
||||
// Perform reduction
|
||||
static_for<0, thread_buf_size, 1>{}([&](auto i_0) {
|
||||
// TODO: use descriptor for this
|
||||
auto v_local = all_scratch[i_0 * num_reduce_warps];
|
||||
|
||||
IndexDataType idx_local{};
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
idx_local = all_indices[i_0 * num_reduce_warps];
|
||||
}
|
||||
|
||||
// further reduce mean/var
|
||||
static_for<0, num_reduce_warps - 1, 1>{}([&](auto i_1_n1) {
|
||||
constexpr auto i_1 = number<i_1_n1 + 1>{};
|
||||
const DataType v_remote = all_scratch[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
// reduce
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
const IndexDataType idx_remote = all_indices[i_0 * num_reduce_warps + i_1];
|
||||
|
||||
bool changed = false;
|
||||
v_local = reduce_func(v_local, v_remote, changed);
|
||||
if(changed)
|
||||
{
|
||||
idx_local = idx_remote;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
v_local = reduce_func(v_local, v_remote);
|
||||
}
|
||||
});
|
||||
|
||||
y_tensor.get_thread_buffer()(i_0) = v_local;
|
||||
if constexpr(kProcessIndex)
|
||||
{
|
||||
y_index_tensor.get_thread_buffer()(i_0) = idx_local;
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename YDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void
|
||||
operator()(YDistributedTensor_& y_tensor, void* smem, const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<false>(y_tensor, y_tensor, smem, nullptr, reduce_func);
|
||||
}
|
||||
|
||||
template <typename YDistributedTensor_, typename YIndexDistributedTensor_, typename ReduceFunc>
|
||||
CK_TILE_DEVICE void operator()(YDistributedTensor_& y_tensor,
|
||||
YIndexDistributedTensor_& y_index_tensor,
|
||||
void* smem,
|
||||
void* smem_indices,
|
||||
const ReduceFunc& reduce_func)
|
||||
{
|
||||
reduce_impl<Problem::kOutputIndex>(
|
||||
y_tensor, y_index_tensor, smem, smem_indices, reduce_func);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -7,12 +7,17 @@
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_, typename ComputeDataType_, typename BlockShape_>
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename BlockShape_,
|
||||
bool OutputIndex_ = false>
|
||||
struct BlockReduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -32,7 +32,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2d<P_>{};
|
||||
}
|
||||
|
||||
@@ -41,7 +42,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -50,7 +52,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
return BlockReduce2dCrossWarpSync<P_>{};
|
||||
}
|
||||
|
||||
@@ -61,7 +64,8 @@ struct Reduce2dDefaultPolicy
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape>;
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile =
|
||||
@@ -76,5 +80,23 @@ struct Reduce2dDefaultPolicy
|
||||
return 1; // zero size arrays are an extension
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetIndicesSmemSize()
|
||||
{
|
||||
using P_ = BlockReduce2dProblem<typename Problem::XDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
typename Problem::BlockShape,
|
||||
Problem::kOutputIndex>;
|
||||
|
||||
using block_reduce2d = BlockReduce2d<P_>;
|
||||
using x_block_tile = decltype(make_static_distributed_tensor<typename Problem::XDataType>(
|
||||
MakeXBlockTileDistribution<Problem>()));
|
||||
using y_index_block_tile =
|
||||
decltype(block_reduce2d::template MakeYIndexBlockTile<x_block_tile, index_t>());
|
||||
|
||||
return GetBlockReduce2dCrossWarpSync<Problem>()
|
||||
.template GetIndicesSmemSize<y_index_block_tile>();
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
|
||||
@@ -11,7 +11,8 @@ template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ReduceOp_>
|
||||
typename ReduceOp_,
|
||||
bool OutputIndex_ = false>
|
||||
struct Reduce2dProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
@@ -20,6 +21,7 @@ struct Reduce2dProblem
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ReduceOp = ReduceOp_;
|
||||
|
||||
static constexpr bool kOutputIndex = OutputIndex_;
|
||||
static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
|
||||
static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user