mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-04-25 08:48:52 +00:00
[ADD] support multi-gpu qlen>1 q5_k
This commit is contained in:
17
setup.py
17
setup.py
@@ -6,7 +6,7 @@ Author : chenxl
|
||||
Date : 2024-07-27 16:15:27
|
||||
Version : 1.0.0
|
||||
LastEditors : chenxl
|
||||
LastEditTime : 2024-07-31 09:44:46
|
||||
LastEditTime : 2024-08-08 02:45:15
|
||||
Adapted from:
|
||||
https://github.com/Dao-AILab/flash-attention/blob/v2.6.3/setup.py
|
||||
Copyright (c) 2023, Tri Dao.
|
||||
@@ -19,6 +19,7 @@ import re
|
||||
import ast
|
||||
import subprocess
|
||||
import platform
|
||||
import shutil
|
||||
import http.client
|
||||
import urllib.request
|
||||
import urllib.error
|
||||
@@ -27,6 +28,7 @@ from packaging.version import parse
|
||||
import torch.version
|
||||
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
|
||||
from setuptools import setup, Extension
|
||||
from cpufeature.extension import CPUFeature
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
|
||||
class CpuInstructInfo:
|
||||
@@ -67,6 +69,8 @@ class VersionInfo:
|
||||
"""
|
||||
if sys.platform.startswith("linux"):
|
||||
return f'linux_{platform.uname().machine}'
|
||||
elif sys.platform == "win32":
|
||||
return "win_amd64"
|
||||
else:
|
||||
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
||||
|
||||
@@ -97,6 +101,15 @@ class VersionInfo:
|
||||
return 'avx2'
|
||||
raise ValueError(
|
||||
"Unsupported cpu Instructions: {}".format(flags_line))
|
||||
elif sys.platform == "win32":
|
||||
if CPUFeature.get("AVX512bw", False):
|
||||
return 'fancy'
|
||||
if CPUFeature.get("AVX512f", False):
|
||||
return 'avx512'
|
||||
if CPUFeature.get("AVX2", False):
|
||||
return 'avx2'
|
||||
raise ValueError(
|
||||
"Unsupported cpu Instructions: {}".format(str(CPUFeature)))
|
||||
else:
|
||||
raise ValueError("Unsupported platform: {}".format(sys.platform))
|
||||
|
||||
@@ -154,7 +167,7 @@ class BuildWheelsCommand(_bdist_wheel):
|
||||
|
||||
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
|
||||
print("Raw wheel path", wheel_path)
|
||||
os.rename(wheel_filename, wheel_path)
|
||||
shutil.move(wheel_filename, wheel_path)
|
||||
except (urllib.error.HTTPError, urllib.error.URLError, http.client.RemoteDisconnected):
|
||||
print("Precompiled wheel not found. Building from source...")
|
||||
# If the wheel could not be downloaded, build from source
|
||||
|
||||
Reference in New Issue
Block a user