WIP with checks passing

This commit is contained in:
Matti Eskelinen
2026-01-19 09:40:32 -05:00
parent 5a0fea7f5a
commit 9e79b07298
4 changed files with 93 additions and 35 deletions

View File

@@ -0,0 +1,12 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_default_policy.hpp"
#include "ck_tile/ops/sinkhorn_knopp/pipeline/sinkhorn_knopp_problem.hpp"
#include "ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
#include "ck_tile/ops/common/streamk_common.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"

View File

@@ -1,3 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
@@ -18,13 +20,14 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
tuple<sequence<2, 1>, sequence<2, 1>>,
tuple<sequence<1, 1>, sequence<2, 1>>, // WarpPerBlock_M, WarpPerBlock_N; ThreadPerWarp_M, ThreadPerWarp_N
tuple<sequence<1, 1>, sequence<2, 1>>, // WarpPerBlock_M, WarpPerBlock_N;
// ThreadPerWarp_M, ThreadPerWarp_N
// sequence<1, 1, 2, 2>,
// sequence<0, 3, 0, 3>>{}); // Repeat_N, ThreadTile_N, Repeat_M, ThreadTile_M
sequence<2, 2, 1, 1>,
sequence<0, 2, 0, 3>>{}); // Repeat_M, ThreadTile_M, Repeat_N, ThreadTile_N
}
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
@@ -34,7 +37,7 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
sequence<>,
tuple<
sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::ThreadTile_M>>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::ThreadTile_N>,
tuple<sequence<1, 2>, sequence<1, 2>>,
tuple<sequence<1, 1>, sequence<1, 1>>,
sequence<1, 1, 2, 2>,
@@ -42,4 +45,4 @@ struct SinkhornKnoppDefaultPolicy : public Reduce2dDefaultPolicy
}
};
} // namespace ck_tile
} // namespace ck_tile

View File

@@ -1,31 +1,40 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
namespace ck_tile {
template <WarpPerBlock_M, WarpPerBlock_N, ThreadPerWarp_M, ThreadPerWarp_N, ThreadTile_M, ThreadTile_N, Repeat_M, Repeat_N>
template <WarpPerBlock_M,
WarpPerBlock_N,
ThreadPerWarp_M,
ThreadPerWarp_N,
ThreadTile_M,
ThreadTile_N,
Repeat_M,
Repeat_N>
struct SinkHornKnoppShape
{
static constexpr index_t Block_M = WarpPerBlock_M;
static constexpr index_t Block_N = WarpPerBlock_N;
static constexpr index_t Block_M = WarpPerBlock_M;
static constexpr index_t Block_N = WarpPerBlock_N;
static constexpr index_t ThreadPerWarp_M = ThreadPerWarp_M;
static constexpr index_t ThreadPerWarp_N = ThreadPerWarp_N;
static constexpr index_t ThreadTile_M = ThreadTile_M;
static constexpr index_t ThreadTile_N = ThreadTile_N;
static constexpr index_t Repeat_M = Repeat_M;
static constexpr index_t Repeat_N = Repeat_N;
static constexpr index_t ThreadTile_M = ThreadTile_M;
static constexpr index_t ThreadTile_N = ThreadTile_N;
static constexpr index_t Repeat_M = Repeat_M;
static constexpr index_t Repeat_N = Repeat_N;
};
template <typename _XDataType,
typename _YDataType,
typename _BlockShape,
typename _ComputeDataType = float>
template <typename _XDataType,
typename _YDataType,
typename _BlockShape,
typename _ComputeDataType = float>
struct SinkhornKnoppProblem
{
using XDataType = remove_cvref_t<_XDataType>;
using XDataType = remove_cvref_t<_XDataType>;
using ComputeDataType = remove_cvref_t<_ComputeDataType>;
using YDataType = remove_cvref_t<_YDataType>;
using YDataType = remove_cvref_t<_YDataType>;
using BlockShape = remove_cvref_t<_BlockShape>;
};
} // namespace ck_tile
} // namespace ck_tile

View File

@@ -1,3 +1,5 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
@@ -8,22 +10,24 @@ struct SinkhornKnoppArgs
{
void* out;
const void* p_x;
const index_t input_m;
int max_iterations;
};
struct SinkhornKnoppKernelReduce
{
template <typename Problem>
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const {
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
{
// Creating tensor descriptors, views and windows for inputs and outputs
// Create the reduce ops
// * Reduce Op ADD for row and column sums
// * Elementwise Op EXP for exponentiation
// * Reduce Op ADD for row and column sums
// * Elementwise Op EXP for exponentiation
using ExponentiationOp = ElementwiseOp<ExponentiationOperation>;
using AddOp = ElementwiseOp<AddOperation>;
using DivideOp = ElementwiseOp<DivideOperation>;
using AddOp = ElementwiseOp<AddOperation>;
using DivideOp = ElementwiseOp<DivideOperation>;
using ReduceOp = ReduceOp<AddOp, AddOp>;
// Run the first steps iteration of the Sinkhorn-Knopp algorithm
@@ -32,7 +36,8 @@ struct SinkhornKnoppKernelReduce
// Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations
// Use BlockReduce2D for row and column sums
for (int i = 0; i <= args.max_iterations; i++) {
for(int i = 0; i <= args.max_iterations; i++)
{
// 0. LOAD x
// 1. Compute row sums (REDUCE)
// 2. Divide values by row sums (SWEEP)
@@ -45,27 +50,56 @@ struct SinkhornKnoppKernelReduce
}
};
template <typename Problem, typename Policy>
struct SinkhornKnoppKernelDummyNonStochastic
{
template <typename Problem>
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const {
template <typename XDistributedTensor_>
CK_TILE_DEVICE static auto MakeYBlockTile()
{
constexpr auto dstr =
make_static_tile_distribution(detail::make_reduce_tile_distribution_encoding(
XDistributedTensor_::get_tile_distribution()
.get_static_tile_distribution_encoding(),
sequence<0>{}));
auto tensor = make_static_distributed_tensor<Problem::ComputeDataType>(dstr);
return tensor;
}
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
{
using S = Problem::BlockShape;
auto desc = make_naive_tensor_descriptor();
static_assert(S::Block_M == S::Block_N, "Input must be a square matrix!");
const auto x_desc = make_naive_tensor_descriptor(make_tuple(args.input_m, args.input_m),
make_tuple(args.input_m, 1),
number<args.input_m>{},
number<1>{});
auto buffer_view = make_buffer_view<address_space_enum::global>(
args.p_x, desc.get_element_space_size(), number<0>{});
const auto x_tensor =
tensor_view<decltype(buffer_view), decltype(desc)>{buffer_view, desc};
tensor_view<decltype(buffer_view), decltype(x_desc)>{buffer_view, x_desc};
auto x_window =
make_tile_window(x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto x_window = make_tile_window(x_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
auto out_buffer_view = make_buffer_view<address_space_enum::global>(
args.out, x_desc.get_element_space_size(), number<0>{});
const auto y_tensor =
tensor_view<decltype(out_buffer_view), decltype(x_desc)>{out_buffer_view, x_desc};
auto y_window = make_tile_window(y_tensor,
make_tuple(number<S::Block_M>{}, number<S::Block_N>{}),
{0, 0},
Policy::template MakeXBlockTileDistribution<Problem>());
}
};
} // namespace ck_tile
} // namespace ck_tile