mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-04-20 14:59:01 +00:00
update api usage (#2969)
This commit is contained in:
@@ -27,10 +27,7 @@
|
||||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
import argparse
|
||||
import ctypes
|
||||
import os
|
||||
from math import prod
|
||||
from functools import partial
|
||||
from typing import Optional, Tuple, Type, Union
|
||||
|
||||
import numpy as np
|
||||
@@ -49,14 +46,9 @@ import cutlass.pipeline as pipeline
|
||||
import cutlass.utils.blackwell_helpers as sm100_utils
|
||||
from cutlass.cute.nvgpu import cpasync, tcgen05
|
||||
from cutlass.cute.runtime import from_dlpack
|
||||
from cutlass.cute.typing import Pointer, Int32, Float16, BFloat16, Float32, Uint32, Float8E4M3FN, Float8E5M2
|
||||
from cutlass.cutlass_dsl import dsl_user_op, T
|
||||
from cutlass._mlir.dialects import llvm, nvvm
|
||||
from cutlass._mlir.dialects.nvvm import (
|
||||
MemOrderKind,
|
||||
MemScopeKind,
|
||||
AtomicOpKind,
|
||||
)
|
||||
from cutlass.cute.typing import Int32, Float16, BFloat16, Float32, Float8E4M3FN, Float8E5M2
|
||||
from cutlass.cutlass_dsl import T
|
||||
from cutlass._mlir.dialects import llvm
|
||||
|
||||
try:
|
||||
import nvshmem.core
|
||||
@@ -1394,24 +1386,22 @@ class PersistentDenseGemmKernel:
|
||||
if lane_id == 0:
|
||||
res = 0
|
||||
while res < self.num_ranks:
|
||||
res = nvvm.load_ext(T.i32(), flag.llvm_ptr, order=MemOrderKind.RELAXED, scope=MemScopeKind.GPU)
|
||||
res = cute.arch.load(flag.llvm_ptr, cutlass.Int32, sem="relaxed", scope="gpu")
|
||||
cute.arch.barrier(
|
||||
barrier_id=self.reduce_scatter_sync_bar_id,
|
||||
number_of_threads=32 * len(self.reduce_scatter_warp_id),
|
||||
)
|
||||
if warp_idx == self.reduce_scatter_warp_id[0]:
|
||||
if lane_id == 0:
|
||||
res = nvvm.atomicrmw(
|
||||
T.i32(),
|
||||
AtomicOpKind.ADD,
|
||||
res = cute.arch.atomic_add(
|
||||
flag.llvm_ptr,
|
||||
Int32(1).ir_value(),
|
||||
mem_order=MemOrderKind.RELAXED,
|
||||
syncscope=MemScopeKind.SYS,
|
||||
Int32(1),
|
||||
sem="relaxed",
|
||||
scope="sys",
|
||||
)
|
||||
res = nvvm.load_ext(T.i32(), flag.llvm_ptr, order=MemOrderKind.RELAXED, scope=MemScopeKind.SYS)
|
||||
res = cute.arch.load(flag.llvm_ptr, cutlass.Int32, sem="relaxed", scope="sys")
|
||||
if res == self.num_ranks*2:
|
||||
nvvm.store_ext(Int32(0), flag.llvm_ptr, order=MemOrderKind.RELAXED, scope=MemScopeKind.SYS)
|
||||
cute.arch.store(flag.llvm_ptr, Int32(0), sem="relaxed", scope="sys")
|
||||
tCgC_mc = thr_mma.partition_C(gC_mc)
|
||||
tCpC = thr_mma.partition_C(cC)
|
||||
|
||||
@@ -2460,7 +2450,7 @@ def run(
|
||||
)
|
||||
for free_func, tensor in free_func_and_tensor_pairs:
|
||||
free_func(tensor)
|
||||
print(f"exec_time: {exec_time}")
|
||||
print(f"exec_time: {exec_time}\n")
|
||||
|
||||
return exec_time # Return execution time in microseconds
|
||||
|
||||
|
||||
Reference in New Issue
Block a user