diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp index 093b5dd8c3..8ca5654bdf 100644 --- a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_knopp_kernel.hpp @@ -12,7 +12,41 @@ struct SinkhornKnoppArgs int max_iterations; }; -struct SinkhornKnoppKernel +struct SinkhornKnoppKernelReduce +{ + template + CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const { + // Creating tensor descriptors, views and windows for inputs and outputs + + // Create the reduce ops + // * Reduce Op ADD for row and column sums + // * Elementwise Op EXP for exponentiation + + using ExponentiationOp = ElementwiseOp; + using AddOp = ElementwiseOp; + using DivideOp = ElementwiseOp; + + using ReduceOp = ReduceOp; + // Run the first steps iteration of the Sinkhorn-Knopp algorithm + // Exponentiate the matrix x + auto x = load_tile(...); + + // Hot loop for Sinkhorn-Knopp iterations from 1 to max_iterations + // Use BlockReduce2D for row and column sums + for (int i = 0; i <= args.max_iterations; i++) { + // 0. LOAD x + // 1. Compute row sums (REDUCE) + // 2. Divide values by row sums (SWEEP) + // 3. STORE the result of the division (in transposed format) + // 4. LOAD transposed x + // 5. Compute column sums (REDUCE) + // 6. Divide values by column sums (SWEEP) + // 7. STORE the result of the division (in transposed format) + } + } +}; + +struct SinkhornKnoppKernelDummyNonStochastic { template CK_TILE_DEVICE void operator()(const SinkhornKnoppArgs& args) const {