This commit is contained in:
Matti Eskelinen
2026-02-06 10:27:17 +00:00
parent 683250e41b
commit 2b5a5e364c
3 changed files with 1 addition and 99 deletions

View File

@@ -24,7 +24,6 @@ void sinkhorn_knopp_ref(const HostTensor<XDataType>& x_n_n,
for(index_t j = 0; j < input_n; ++j)
{
c_n_n(i, j) = exp(type_convert<ComputeDataType>(x_n_n(i, j)));
// c_n_n(i, j) = type_convert<ComputeDataType>(x_n_n(i, j));
}
}

View File

@@ -61,7 +61,7 @@ struct SinkhornKnoppShape
template <typename _InDataType,
typename _OutDataType,
typename _BlockShape,
typename _ComputeDataType = _OutDataType>
typename _ComputeDataType = float>
struct SinkhornKnoppProblem
{
using InDataType = remove_cvref_t<_InDataType>;

View File

@@ -207,101 +207,4 @@ struct SinkhornKnoppKernelReduce
}
};
// Dummy (non-stochastic) Sinkhorn-Knopp kernel: performs one tile load from
// the input (result discarded) and writes a single constant value -- the
// identity of the Add reduction -- to every element of the output tile.
// NOTE(review): appears to be a placeholder / pipeline smoke-test path rather
// than a real computation; confirm intended use with callers.
template <typename Problem, typename Policy>
struct SinkhornKnoppKernelDummyNonStochastic
{
// Workgroup size taken from the problem's block-shape description.
static constexpr index_t kBlockSize = Problem::BlockShape::BlockSize;
// Halve the launch block size on wave32 hardware so the same tile shape
// maps onto the narrower wavefronts.
CK_TILE_HOST static constexpr auto BlockSize()
{
return is_wave32() ? kBlockSize / 2 : kBlockSize;
}
// template <typename XDistributedTensor_>
// CK_TILE_DEVICE static auto MakeYBlockTile()
// {
// constexpr auto dstr =
// make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
// XDistributedTensor_::get_tile_distribution()
// .get_static_tile_distribution_encoding(),
// sequence<0>{}));
// auto tensor = make_static_distributed_tensor<typename Problem::OutDataType>(dstr);
// return tensor;
// }
// Device entry point. `args` supplies the raw input/output pointers and the
// matrix extent `args.input_m`; both tensors are treated as input_m x input_m
// row-major matrices (the static_assert below requires a square block tile).
CK_TILE_DEVICE void operator()([[maybe_unused]] const SinkhornKnoppArgs& args) const
{
using S = Problem::BlockShape;
using InDataType = typename Problem::InDataType;
using OutDataType = typename Problem::OutDataType;
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
auto* p_in = static_cast<const Problem::InDataType*>(args.p_in);
auto* p_out = static_cast<Problem::OutDataType*>(args.p_out);
// The Add-reduction identity doubles as both the out-of-bounds padding
// value for the buffer views and the constant written to the output below.
auto reduce_func = ck_tile::ReduceOp::Add{};
const InDataType custom_padding_value = type_convert<InDataType>(
reduce_func.GetIdentityValue<typename Problem::ComputeDataType>());
// Row-major input_m x input_m descriptor (stride input_m, 1).
const auto in_desc =
make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
number<4>{}, // TODO: Hardcoded
// vectorization, //we should calculate it!
number<1>{});
// Global-memory view over the input, padded with the reduce identity.
auto buffer_view = make_buffer_view<address_space_enum::global>(
p_in, in_desc.get_element_space_size(), custom_padding_value);
const auto input_tensor =
tensor_view<decltype(buffer_view), decltype(in_desc)>{buffer_view, in_desc};
// Block_M x Block_N window at the origin, distributed across the block
// with the policy's X-tile distribution.
[[maybe_unused]] auto input_window =
make_tile_window(input_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
// Output reuses the same descriptor (same shape/strides as the input).
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
p_out,
in_desc.get_element_space_size(),
type_convert<OutDataType>(custom_padding_value));
auto out_tensor =
tensor_view<decltype(out_buffer_view), decltype(in_desc)>{out_buffer_view, in_desc};
[[maybe_unused]] auto out_window =
make_tile_window(out_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
// Dummy copy from input to output
// (the loaded tile is never consumed; the load only exercises the path).
[[maybe_unused]] auto input_tile = load_tile(input_window);
// auto out_tile = MakeYBlockTile<decltype(input_window)>();
// Register-resident output tile with the same distribution as the input.
auto out_tile = make_static_distributed_tensor<OutDataType>(
Policy::template MakeXBlockTileDistribution<Problem>());
// Set all output elements to the custom padding value.
// // Simple solution to set the whole tile to a constant //
// set_tile(out_tile, custom_padding_value);
// store_tile(out_window, out_tile);
// Per-element sweep over both distributed spans of the tile, writing the
// constant, then a single store back to global memory.
constexpr auto y_spans = out_tile.get_distributed_spans();
sweep_tile_span(y_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(y_spans[number<1>{}], [&](auto idx1) {
constexpr auto distributed_indices = make_tuple(idx0, idx1);
out_tile(distributed_indices) = type_convert<OutDataType>(custom_padding_value);
});
});
store_tile(out_window, out_tile);
}
};
} // namespace ck_tile