mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-13 01:35:45 +00:00
Updates for CUTLASS 3.5.0 (#1468)
This commit is contained in:
@@ -1654,7 +1654,7 @@ class GemmOperationBase:
|
||||
extended_name = "{core_name}_{element_a}_{element_b}_{element_acc}_{element_c}_{element_d}".format(
|
||||
element_a=DataTypeNames[self.A.element],
|
||||
element_b=DataTypeNames[self.B.element],
|
||||
element_acc=DataTypeNames[self.tile_description.math_instruction.element_accumulator],
|
||||
element_acc=DataTypeNames[self.accumulator_type()],
|
||||
element_c=DataTypeNames[self.C.element],
|
||||
element_d=DataTypeNames[self.epilogue_functor.element_output],
|
||||
core_name=self.core_name())
|
||||
|
||||
@@ -118,16 +118,18 @@ cutlass::Status ${name}_kernel_run(
|
||||
typename DeviceKernel::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{M, N, K, L}, // problem size
|
||||
A, // ptrA
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
||||
B, // ptrB
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
||||
{
|
||||
A, // ptrA
|
||||
cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)), // stride A
|
||||
B, // ptrB
|
||||
cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)), // stride B
|
||||
},
|
||||
{
|
||||
{alpha, beta},
|
||||
C, // ptrC
|
||||
cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)), // stride C
|
||||
D, // ptrD
|
||||
cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)), // stride D
|
||||
{alpha, beta},
|
||||
},
|
||||
hw_info
|
||||
};
|
||||
|
||||
@@ -232,7 +232,7 @@ _PYTORCH_GEMM_INCLUDES = {
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
""",
|
||||
}
|
||||
@@ -583,7 +583,11 @@ setup(
|
||||
'${name}_kernel.cu',
|
||||
],
|
||||
include_dirs=['${cutlass_path}/include', '${cutlass_path}/tools/util/include'],
|
||||
extra_compile_args=['-std=c++17']
|
||||
extra_compile_args={
|
||||
'cxx': ['-std=c++17'],
|
||||
'nvcc': ['-std=c++17', ${extra_compile_args}],
|
||||
},
|
||||
libraries=['cuda']
|
||||
),
|
||||
],
|
||||
cmdclass={
|
||||
@@ -593,7 +597,7 @@ setup(
|
||||
"""
|
||||
|
||||
|
||||
def _generate_setup(name: str, sourcedir: str):
|
||||
def _generate_setup(name: str, sourcedir: str, extra_compile_args: str=""):
|
||||
"""
|
||||
Generates a setup.py file for the extension
|
||||
|
||||
@@ -601,10 +605,12 @@ def _generate_setup(name: str, sourcedir: str):
|
||||
:type name: str
|
||||
:param sourcedir: directory to which generated source files should be written
|
||||
:type sourcedir: str
|
||||
:param extra_compile_args: additional arguments to pass to setup.py
|
||||
:type extra_args: str
|
||||
"""
|
||||
setup_py_file = os.path.join(sourcedir, "setup.py")
|
||||
setup_source = SubstituteTemplate(
|
||||
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH}
|
||||
_PYTORCH_SETUP_PY, {"name": name, "cutlass_path": CUTLASS_PATH, "extra_compile_args": extra_compile_args}
|
||||
)
|
||||
with open(setup_py_file, "w") as outfile:
|
||||
outfile.write(setup_source)
|
||||
@@ -696,6 +702,7 @@ def _jit(name: str, cc: int, cpp_file: str, cuda_file: str):
|
||||
os.path.join(CUTLASS_PATH, "include"),
|
||||
os.path.join(CUTLASS_PATH, "tools/util/include"),
|
||||
],
|
||||
extra_ldflags=["-lcuda"],
|
||||
verbose=(logger.level == logging.DEBUG)
|
||||
)
|
||||
return jitmodule
|
||||
@@ -759,7 +766,10 @@ def _pytorch_gemm(op, name: str, cc: int, jit: bool = False, sourcedir: str = ""
|
||||
with open(cpp_file, "w") as outfile:
|
||||
outfile.write(cpp_source)
|
||||
|
||||
_generate_setup(name, sourcedir)
|
||||
extra_compile_args = ""
|
||||
if cc == 90:
|
||||
extra_compile_args = "'--generate-code=arch=compute_90a,code=[sm_90a]'"
|
||||
_generate_setup(name, sourcedir, extra_compile_args)
|
||||
|
||||
if jit:
|
||||
return _jit(name, cc, cpp_file, cuda_file)
|
||||
|
||||
@@ -137,9 +137,9 @@ class KernelsForDataType:
|
||||
# Finally, go through all available alignment combinations and find
|
||||
# one for which all values are less than those passed in.
|
||||
key = None
|
||||
alignments = sorted([(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
|
||||
alignments = sorted([tuple(int(x) for x in k.split(" ")) for k in self.kernels_by_alignment.keys()], reverse=True)
|
||||
for align_A, align_B, align_C in alignments:
|
||||
if align_A <= alignment_A and align_B <= alignment_B and align_C <= alignment_C:
|
||||
if alignment_A % align_A == 0 and alignment_B % align_B == 0 and alignment_C % align_C == 0:
|
||||
key = f"{align_A} {align_B} {align_C}"
|
||||
break
|
||||
|
||||
|
||||
@@ -712,4 +712,4 @@ class Gemm(OperationBase):
|
||||
if sync:
|
||||
arguments.sync()
|
||||
|
||||
return arguments
|
||||
return arguments
|
||||
|
||||
Reference in New Issue
Block a user