Split into dummy and reduce impls

This commit is contained in:
Matti Eskelinen
2026-01-19 08:35:29 -05:00
parent 86c47d3a9d
commit 2c7fb73c2c

View File

@@ -12,7 +12,41 @@ struct SinkhornKnoppArgs
int max_iterations;
};
struct SinkhornKnoppKernel
/// Reduce-based implementation of the Sinkhorn-Knopp kernel (work-in-progress skeleton).
///
/// The body currently only outlines the intended algorithm: the type aliases and the
/// iteration loop are in place, while the individual load / reduce / sweep / store
/// steps are still described by comments and have not been implemented yet.
struct SinkhornKnoppKernelReduce
{
    /// Kernel entry point.
    ///
    /// @tparam Problem  Not referenced by the current body. It does not appear in the
    ///                  parameter list, so callers have to supply it explicitly.
    /// @param  args     Kernel arguments; only `args.max_iterations` is read here.
    template <typename Problem>
    CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const
    {
        // TODO: create tensor descriptors, views and windows for inputs and outputs.

        // Operations the algorithm is planned to use:
        //  * ADD reduce for the row and column sums
        //  * elementwise EXP for the exponentiation
        using ExponentiationOp = ElementwiseOp<ExponentiationOperation>;
        using AddOp            = ElementwiseOp<AddOperation>;
        using DivideOp         = ElementwiseOp<DivideOperation>;
        // Deliberately not named `ReduceOp`: a local alias with that name would shadow
        // the `ReduceOp` template for the remainder of this function.
        using SumReduceOp = ReduceOp<AddOp, AddOp>;

        // Initial step of the Sinkhorn-Knopp algorithm, performed once before the
        // iterations: exponentiate the matrix x.
        auto x = load_tile(...); // placeholder: the real arguments are still to be filled in

        // Hot loop: Sinkhorn-Knopp iterations 1..max_iterations (inclusive), i.e. exactly
        // `max_iterations` passes. Starting at 0 with `<=` would run one pass too many.
        // BlockReduce2D is the intended primitive for the row and column sums.
        for (int iteration = 1; iteration <= args.max_iterations; ++iteration) {
            // 0. LOAD x
            // 1. Compute row sums (REDUCE)
            // 2. Divide values by row sums (SWEEP)
            // 3. STORE the result of the division (in transposed format)
            // 4. LOAD transposed x
            // 5. Compute column sums (REDUCE)
            // 6. Divide values by column sums (SWEEP)
            // 7. STORE the result of the division (in transposed format)
        }
    }
};
struct SinkhornKnoppKernelDummyNonStochastic
{
template <typename Problem>
CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const {