From 86c47d3a9dc76ece996c5e4078e5e124bd6c9d3d Mon Sep 17 00:00:00 2001 From: Matti Eskelinen Date: Mon, 19 Jan 2026 08:06:54 -0500 Subject: [PATCH] Add example script on implementing the algoithm using matmul --- .../ops/sinkhorn_knopp/sinkhorn_gemm_ex.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 include/ck_tile/ops/sinkhorn_knopp/sinkhorn_gemm_ex.py diff --git a/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_gemm_ex.py b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_gemm_ex.py new file mode 100644 index 0000000000..a5724ec5fe --- /dev/null +++ b/include/ck_tile/ops/sinkhorn_knopp/sinkhorn_gemm_ex.py @@ -0,0 +1,41 @@ +def sinkhorn_knopp_ref(m, max_iter): + for i in range(max_iter): + m /= m.sum(axis=0, keepdims=True) + m /= m.sum(axis=1, keepdims=True) + return m + +def sinkhorn_knopp_mm(m, max_iter): + sums = np.ones(m.shape[1]) + for i in range(max_iter): + sums = (1 / sums) @ m + sums = m @ ( 1 / sums.T) + return m / sums + +if __name__=="__main__": + import time + import numpy as np + REPS = 10000 + sh = (10,10) + max_iters = 20 + + t0 = time.time() + for _ in range(REPS): + m = np.random.rand(*sh) + sinkhorn_knopp_ref(m, max_iters) + t1 = time.time() + print(f"{t1-t0} seconds for ref") + + t0 = time.time() + for _ in range(REPS): + m = np.random.rand(*sh) + sinkhorn_knopp_mm(m, max_iters) + t1 = time.time() + print(f"{t1-t0} seconds for matrix multiply") + + for _ in range(REPS): + m = np.random.rand(*sh) + a = sinkhorn_knopp_ref(m, max_iters) + b = sinkhorn_knopp_mm(m, max_iters) + if not np.isclose(a, b).all(): + print(a-b) + print(f"results match") \ No newline at end of file