update api usage (#2969)

This commit is contained in:
Xiao Song
2026-01-27 15:33:22 +08:00
committed by GitHub
parent 51f82812ec
commit 7a14467776

View File

@@ -27,10 +27,7 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import argparse
import ctypes
import os
from math import prod
from functools import partial
from typing import Optional, Tuple, Type, Union
import numpy as np
@@ -49,14 +46,9 @@ import cutlass.pipeline as pipeline
import cutlass.utils.blackwell_helpers as sm100_utils
from cutlass.cute.nvgpu import cpasync, tcgen05
from cutlass.cute.runtime import from_dlpack
from cutlass.cute.typing import Pointer, Int32, Float16, BFloat16, Float32, Uint32, Float8E4M3FN, Float8E5M2
from cutlass.cutlass_dsl import dsl_user_op, T
from cutlass._mlir.dialects import llvm, nvvm
from cutlass._mlir.dialects.nvvm import (
MemOrderKind,
MemScopeKind,
AtomicOpKind,
)
from cutlass.cute.typing import Int32, Float16, BFloat16, Float32, Float8E4M3FN, Float8E5M2
from cutlass.cutlass_dsl import T
from cutlass._mlir.dialects import llvm
try:
import nvshmem.core
@@ -1394,24 +1386,22 @@ class PersistentDenseGemmKernel:
if lane_id == 0:
res = 0
while res < self.num_ranks:
res = nvvm.load_ext(T.i32(), flag.llvm_ptr, order=MemOrderKind.RELAXED, scope=MemScopeKind.GPU)
res = cute.arch.load(flag.llvm_ptr, cutlass.Int32, sem="relaxed", scope="gpu")
cute.arch.barrier(
barrier_id=self.reduce_scatter_sync_bar_id,
number_of_threads=32 * len(self.reduce_scatter_warp_id),
)
if warp_idx == self.reduce_scatter_warp_id[0]:
if lane_id == 0:
res = nvvm.atomicrmw(
T.i32(),
AtomicOpKind.ADD,
res = cute.arch.atomic_add(
flag.llvm_ptr,
Int32(1).ir_value(),
mem_order=MemOrderKind.RELAXED,
syncscope=MemScopeKind.SYS,
Int32(1),
sem="relaxed",
scope="sys",
)
res = nvvm.load_ext(T.i32(), flag.llvm_ptr, order=MemOrderKind.RELAXED, scope=MemScopeKind.SYS)
res = cute.arch.load(flag.llvm_ptr, cutlass.Int32, sem="relaxed", scope="sys")
if res == self.num_ranks*2:
nvvm.store_ext(Int32(0), flag.llvm_ptr, order=MemOrderKind.RELAXED, scope=MemScopeKind.SYS)
cute.arch.store(flag.llvm_ptr, Int32(0), sem="relaxed", scope="sys")
tCgC_mc = thr_mma.partition_C(gC_mc)
tCpC = thr_mma.partition_C(cC)
@@ -2460,7 +2450,7 @@ def run(
)
for free_func, tensor in free_func_and_tensor_pairs:
free_func(tensor)
print(f"exec_time: {exec_time}")
print(f"exec_time: {exec_time}\n")
return exec_time # Return execution time in microseconds