mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 06:48:59 +00:00
v4.4.1 update (#3079)
This commit is contained in:
@@ -170,10 +170,10 @@ class CtaNorm:
|
||||
print(f"[DSL INFO] pred = {pred.type}")
|
||||
for i in range(cute.size(tXrX, mode=[1])):
|
||||
if pred[i]:
|
||||
cute.autovec_copy(tXgX[None, i], tXrX[None, i]) # LDG.128
|
||||
cute.autovec_copy(tWgW[None, i], tWrW[None, i]) # LDG.128
|
||||
cute.autovec_copy(tXgX[None, i], tXrX[None, i]) # Global load
|
||||
cute.autovec_copy(tWgW[None, i], tWrW[None, i]) # Global load
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
cute.autovec_copy(tBgB[None, i], tBrB[None, i]) # LDG.128
|
||||
cute.autovec_copy(tBgB[None, i], tBrB[None, i]) # Global load
|
||||
if cutlass.const_expr(self.norm_type == "layer"):
|
||||
tYrY = self.apply_layernorm(tXrX, tWrW, tBrB, eps, tidx, pred)
|
||||
elif cutlass.const_expr(self.norm_type == "rms"):
|
||||
@@ -421,4 +421,4 @@ if __name__ == "__main__":
|
||||
warmup_iterations=args.warmup_iterations,
|
||||
iterations=args.iterations,
|
||||
)
|
||||
print("\nPASS")
|
||||
print("\nPASS")
|
||||
|
||||
Reference in New Issue
Block a user