mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Add example script on implementing the algoithm using matmul
This commit is contained in:
41
include/ck_tile/ops/sinkhorn_knopp/sinkhorn_gemm_ex.py
Normal file
41
include/ck_tile/ops/sinkhorn_knopp/sinkhorn_gemm_ex.py
Normal file
@@ -0,0 +1,41 @@
|
||||
def sinkhorn_knopp_ref(m, max_iter):
|
||||
for i in range(max_iter):
|
||||
m /= m.sum(axis=0, keepdims=True)
|
||||
m /= m.sum(axis=1, keepdims=True)
|
||||
return m
|
||||
|
||||
def sinkhorn_knopp_mm(m, max_iter):
|
||||
sums = np.ones(m.shape[1])
|
||||
for i in range(max_iter):
|
||||
sums = (1 / sums) @ m
|
||||
sums = m @ ( 1 / sums.T)
|
||||
return m / sums
|
||||
|
||||
if __name__=="__main__":
|
||||
import time
|
||||
import numpy as np
|
||||
REPS = 10000
|
||||
sh = (10,10)
|
||||
max_iters = 20
|
||||
|
||||
t0 = time.time()
|
||||
for _ in range(REPS):
|
||||
m = np.random.rand(*sh)
|
||||
sinkhorn_knopp_ref(m, max_iters)
|
||||
t1 = time.time()
|
||||
print(f"{t1-t0} seconds for ref")
|
||||
|
||||
t0 = time.time()
|
||||
for _ in range(REPS):
|
||||
m = np.random.rand(*sh)
|
||||
sinkhorn_knopp_mm(m, max_iters)
|
||||
t1 = time.time()
|
||||
print(f"{t1-t0} seconds for matrix multiply")
|
||||
|
||||
for _ in range(REPS):
|
||||
m = np.random.rand(*sh)
|
||||
a = sinkhorn_knopp_ref(m, max_iters)
|
||||
b = sinkhorn_knopp_mm(m, max_iters)
|
||||
if not np.isclose(a, b).all():
|
||||
print(a-b)
|
||||
print(f"results match")
|
||||
Reference in New Issue
Block a user