Fix pytest unstable issue. (#170)

- remove `#include <cstdint>` from `poll.hpp`. To make it only contains
device-side code
- Fix compilation issue, which will cause pytest fail randomly. Reuse
the compiled result for same kernel with different arguments
This commit is contained in:
Binyang2014
2023-09-07 08:09:04 +08:00
committed by GitHub
parent 828be48b21
commit 097aa8843a
13 changed files with 49 additions and 14 deletions

View File

@@ -72,12 +72,18 @@ class Kernel:
class KernelBuilder:
kernel_map: dict = {}
def __init__(self, file: str, kernel_name: str):
self._tempdir = tempfile.TemporaryDirectory()
if kernel_name in self.kernel_map:
self._kernel = self.kernel_map[kernel_name]
return
self._tempdir = tempfile.TemporaryDirectory(suffix=f"{os.getpid()}")
self._current_file_dir = os.path.dirname(os.path.abspath(__file__))
device_id = cp.cuda.Device().id
ptx = self._compile_cuda(os.path.join(self._current_file_dir, file), f"{kernel_name}.ptx", device_id)
self._kernel = Kernel(ptx, kernel_name, device_id)
self.kernel_map[kernel_name] = self._kernel
def _compile_cuda(self, source_file, output_file, device_id, std_version="c++17"):
include_dir = os.path.join(self._current_file_dir, "../../include")
@@ -87,22 +93,34 @@ class KernelBuilder:
minor = _check_cuda_errors(
cudart.cudaDeviceGetAttribute(cudart.cudaDeviceAttr.cudaDevAttrComputeCapabilityMinor, device_id)
)
command = (
f"nvcc -std={std_version} -ptx -Xcompiler -Wall,-Wextra -I{include_dir} {source_file} "
f"--gpu-architecture=compute_{major}{minor} --gpu-code=sm_{major}{minor},compute_{major}{minor} -o {self._tempdir.name}/{output_file}"
)
cuda_home = os.environ.get("CUDA_HOME")
nvcc = os.path.join(cuda_home, "bin/nvcc") if cuda_home else "nvcc"
command = [
nvcc,
f"-std={std_version}",
"-ptx",
"-Xcompiler",
"-Wall,-Wextra",
f"-I{include_dir}",
f"{source_file}",
f"--gpu-architecture=compute_{major}{minor}",
f"--gpu-code=sm_{major}{minor},compute_{major}{minor}",
"-o",
f"{self._tempdir.name}/{output_file}",
]
try:
subprocess.run(command, shell=True, check=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
subprocess.run(command, capture_output=True, text=True, check=True, bufsize=1)
with open(f"{self._tempdir.name}/{output_file}", "rb") as f:
return f.read()
except subprocess.CalledProcessError as e:
raise RuntimeError("Compilation failed:", e.stderr.decode(), command)
raise RuntimeError("Compilation failed:", e.stderr, " ".join(command))
def get_compiled_kernel(self):
return self._kernel
def __del__(self):
self._tempdir.cleanup()
if hasattr(self, "_tempdir"):
self._tempdir.cleanup()
def pack(*args):