v4.4.1 update (#3079)

This commit is contained in:
Junkai-Wu
2026-02-28 02:59:21 +08:00
committed by GitHub
parent c651d660d2
commit 3bb6e28d3c
13 changed files with 92 additions and 23 deletions

View File

@@ -170,10 +170,10 @@ class CtaNorm:
 print(f"[DSL INFO] pred = {pred.type}")
 for i in range(cute.size(tXrX, mode=[1])):
 if pred[i]:
-cute.autovec_copy(tXgX[None, i], tXrX[None, i]) # LDG.128
-cute.autovec_copy(tWgW[None, i], tWrW[None, i]) # LDG.128
+cute.autovec_copy(tXgX[None, i], tXrX[None, i]) # Global load
+cute.autovec_copy(tWgW[None, i], tWrW[None, i]) # Global load
 if cutlass.const_expr(self.norm_type == "layer"):
-cute.autovec_copy(tBgB[None, i], tBrB[None, i]) # LDG.128
+cute.autovec_copy(tBgB[None, i], tBrB[None, i]) # Global load
 if cutlass.const_expr(self.norm_type == "layer"):
 tYrY = self.apply_layernorm(tXrX, tWrW, tBrB, eps, tidx, pred)
 elif cutlass.const_expr(self.norm_type == "rms"):
@@ -421,4 +421,4 @@ if __name__ == "__main__":
 warmup_iterations=args.warmup_iterations,
 iterations=args.iterations,
 )
-print("\nPASS")
+print("\nPASS")