Mirror of https://github.com/kvcache-ai/ktransformers.git

Commit: Merge "fix some bugs"
config.json
@@ -24,8 +24,8 @@ model:
type: balance_serve
# type: ktransformers

name: SmallThinkerForCausalLM
path: /mnt/data/models/Smallthinker-21B
name: DeepSeek-Coder-V2-Instruct
path: deepseek-ai/DeepSeek-V2-Lite-Chat
gguf_path: /mnt/data/models/Smallthinker-21B

device: cuda:0
@@ -12,8 +12,13 @@ import sys
project_dir = os.path.dirname(os.path.dirname(__file__))
sys.path.insert(0, project_dir)
import torch
import torch_npu
from torch_npu.contrib import transfer_to_npu
try:
import torch_npu
from torch_npu.contrib import transfer_to_npu
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group
from ktransformers.util import utils, npu_graph_runner
except:
pass
import torch.distributed as dist

import logging
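This hunk wraps the NPU-only imports in a try/except so the file also loads on CUDA-only machines. A minimal sketch of that guarded-import pattern as it recurs throughout the commit (the `default_device` helper is illustrative, not part of the repository):

```python
import torch

# Optional Ascend NPU support: torch_npu only exists on NPU machines, so the
# import is guarded and a flag records whether the backend is usable.
try:
    import torch_npu
    use_torch_npu = torch_npu.npu.is_available()
except ImportError:
    use_torch_npu = False

def default_device() -> torch.device:
    # Illustrative helper: prefer NPU, then CUDA, then CPU.
    if use_torch_npu:
        return torch.device("npu:0")
    if torch.cuda.is_available():
        return torch.device("cuda:0")
    return torch.device("cpu")
```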
@@ -33,8 +38,7 @@ from ktransformers.models.modeling_deepseek_v3 import DeepseekV3ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
from ktransformers.models.modeling_mixtral import MixtralForCausalLM
from ktransformers.util.utils import prefill_and_generate, get_compute_capability, xpu_fp16_model
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group
from ktransformers.util import utils, npu_graph_runner
from ktransformers.util import utils
from ktransformers.models.custom_cache import StaticCache
from ktransformers.server.config.config import Config
from ktransformers.operators.flashinfer_wrapper import flashinfer_enabled
@@ -247,11 +247,11 @@ class KVC2StaticCache(transformers.Cache):
Static Cache class connect with KVC2
remind: page_idx & page_offset info need to refs to forward batching, only contains KV Block Tensor here
"""
def __init__(self, config: PretrainedConfig, max_batch_size, page_size: int = 256, dtype=torch.bfloat16, device=torch.device("npu:0")) -> None:
def __init__(self, config: PretrainedConfig, max_batch_size, page_size: int = 256, dtype=torch.bfloat16, device=None) -> None:
super().__init__()
self.config = config
self.dtype = dtype
self.device = device
self.device = torch.device("npu:0")
self.kv_lora_rank = config.kv_lora_rank
self.max_batch_size = max_batch_size
self.page_size = page_size
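The paired `__init__` signatures above show the fix: the hard-coded `torch.device("npu:0")` default argument is replaced by `device=None` and resolved inside the constructor. A minimal sketch of that resolve-at-call-time pattern (`resolve_device` and the flag argument are illustrative):

```python
import torch

def resolve_device(device=None, use_torch_npu: bool = False) -> torch.device:
    # Avoid evaluating torch.device("npu:0") in a default argument (it would be
    # bound at import time, even on machines without an NPU); decide at call time.
    if device is not None:
        return torch.device(device)
    return torch.device("npu:0") if use_torch_npu else torch.device("cuda:0")
```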
@@ -361,7 +361,7 @@ class KVC2StaticCache(transformers.Cache):
self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1] # page_len * page_size

# todo-luo: this get_page_table takes different arguments than the other two classes
def get_page_table(self, mini_batch: ForwardMiniBatchSplit, bsz_tensors: torch.tensor = None, is_prefill=True):
def get_page_table(self, mini_batch, bsz_tensors: torch.tensor = None, is_prefill=True):
if is_prefill:
# TODO add padding support
q_lens = [mini_batch.p_q_len[idx] for idx in range(mini_batch.prefill_batch)]
@@ -37,8 +37,11 @@ def inject(module, local_optimization_dict, model_config:AutoConfig ,gguf_loader
gguf_loader.tensor_device_map[inject_module_meta["key"]] = inject_module_meta["kwargs"] if "kwargs" in inject_module_meta else dict()
import_class_name = import_path[-1]
module_cls=getattr(__import__(import_module_name, fromlist=[""]), import_class_name)
print(f"Injecting {child_prefix} as", import_module_name, ".",
import_class_name) if torch.distributed.get_rank() == 0 else None
if use_torch_npu:
print(f"Injecting {child_prefix} as", import_module_name, ".",
import_class_name) if torch.distributed.get_rank() == 0 else None #TODO distributed
else:
print(f"Injecting {child_prefix} as", import_module_name, ".", import_class_name)
inject_module=module_cls(key = inject_module_meta["key"], gguf_loader = gguf_loader, config = model_config, orig_module=child, **inject_module_meta["kwargs"])
set_module(module, name, inject_module)
elif inject_module_meta["class"] == "default":
@@ -63,7 +66,8 @@ def del_meta(module:nn.Module):

def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, prefix: str="", default_device: str = "cuda:0"):
module_name = prefix[:-1]
translated_name = translate_name_to_gguf(prefix)[:-1]
if use_torch_npu:
module_name = translate_name_to_gguf(prefix)[:-1] #TODO this variable is not used in the main repo
recursive = True
for rule in rule_list:
match_meta = rule["match"]
@@ -84,7 +88,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
if "replace" in rule:
replace_meta = rule["replace"]
if module_name not in out_data:
out_data[module_name]={"key": translated_name,
out_data[module_name]={"key": module_name,
"class": replace_meta["class"] if "class" in replace_meta else "default",
# "device": replace_meta["device"] if "device" in replace_meta else default_device,
"kwargs": copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict()}
@@ -99,7 +103,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
if module_name not in out_data:
out_data[module_name]= {
"class": "default",
"key": translated_name,
"key": module,
"kwargs": {"generate_device": default_device,
"prefill_device": default_device}
}
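The two `"key": ...` hunks above change which name is stored in the generated optimize config. For context, a minimal sketch of how a match/replace rule becomes such an entry; the rule contents and the class path are illustrative, not the repository's actual rule files:

```python
import copy

# Illustrative rule: replace expert modules with a custom operator class.
rule = {
    "match": {"name": r"^model\.layers\..*\.mlp\.experts$"},
    "replace": {
        "class": "ktransformers.operators.experts.KTransformersExperts",  # assumed class path
        "kwargs": {"generate_device": "cuda:0", "prefill_device": "cuda:0"},
    },
}

def make_entry(key_name: str, rule: dict, default_device: str = "cuda:0") -> dict:
    # Build an out_data entry in the shape shown in the hunks above.
    replace_meta = rule.get("replace", {})
    return {
        "key": key_name,
        "class": replace_meta.get("class", "default"),
        "kwargs": copy.deepcopy(replace_meta.get(
            "kwargs", {"generate_device": default_device, "prefill_device": default_device})),
    }
```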
@@ -1,6 +1,5 @@
from typing import Any, AsyncIterator, List, Optional, Set
from ktransformers.models.custom_cache import KVC2StaticCache, KDeepSeekV3Cache, KGQACache
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size
from transformers import (
AutoTokenizer,
AutoConfig,
@@ -40,6 +39,7 @@ except:
use_torch_npu = False
if use_torch_npu:
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel, get_tensor_parallel_group, get_tensor_parallel_size

from ktransformers.models.modeling_deepseek import DeepseekV2ForCausalLM
from ktransformers.models.modeling_llama import LlamaForCausalLM
@@ -50,8 +50,8 @@ custom_models = {
"Qwen2MoeForCausalLM": Qwen2MoeForCausalLM,
"LlamaForCausalLM": LlamaForCausalLM,
"MixtralForCausalLM": MixtralForCausalLM,
}
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner
} #TODO NPU-specific?
from ktransformers.server.balance_serve.inference.model_runner import ModelRunner, get_or_create_model_runner #TODO is get_or_create_model_runner NPU-only?
from ktransformers.models.configuration_qwen3_next import Qwen3NextConfig
from ktransformers.server.balance_serve.inference.sampling.sampler import Sampler, SamplingOptions
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
@@ -214,18 +214,20 @@ class Engine:
self.model = KQwen3NextForCausalLM(config, self.cache)

context = zmq.Context()

if torch.distributed.get_rank() == 0:
if use_torch_npu:
if torch.distributed.get_rank() == 0:
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
self.sub_socket = None
else:
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(f"ipc://{broadcast_endpoint}")
self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.pub_socket = None
# time.sleep(1) # make sure all subscribers are ready
else:
self.pub_socket = context.socket(zmq.PUB)
self.pub_socket.bind(f"ipc://{broadcast_endpoint}")
self.sub_socket = None
else:
self.sub_socket = context.socket(zmq.SUB)
self.sub_socket.connect(f"ipc://{broadcast_endpoint}")
self.sub_socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.pub_socket = None
# time.sleep(1) # make sure all subscribers are ready


try:
generation_config = GenerationConfig.from_pretrained(args.model_dir)
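The Engine hunk above gates the ZMQ setup behind `use_torch_npu`, but both branches share the same pattern: rank 0 publishes scheduler batches over an `ipc://` socket and every other rank subscribes. A minimal, self-contained sketch of that pattern (the function name is illustrative):

```python
import zmq
import torch.distributed as dist

def make_broadcast_sockets(broadcast_endpoint: str):
    # Rank 0 publishes the next batch; all other ranks block on the subscription.
    context = zmq.Context()
    if dist.get_rank() == 0:
        pub = context.socket(zmq.PUB)
        pub.bind(f"ipc://{broadcast_endpoint}")
        return pub, None
    sub = context.socket(zmq.SUB)
    sub.connect(f"ipc://{broadcast_endpoint}")
    sub.setsockopt_string(zmq.SUBSCRIBE, "")  # no topic filter: receive everything
    return None, sub
```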
@@ -248,11 +250,13 @@ class Engine:
"please input the path of your gguf file(gguf file in the dir containing input gguf file must all"
" belong to current model):"
)
tp_group = get_tensor_parallel_group()
torch.distributed.barrier(group=tp_group)
if use_torch_npu:
tp_group = get_tensor_parallel_group()
torch.distributed.barrier(group=tp_group)
optimize_and_load_gguf(self.model, optimize_config_path, gguf_path, config)
get_absort_weight(self.model, config)
torch.distributed.barrier(group=tp_group)
if use_torch_npu:
get_absort_weight(self.model, config) #TODO
torch.distributed.barrier(group=tp_group)
self.model.generation_config = generation_config
if self.model.generation_config.pad_token_id is None:
self.model.generation_config.pad_token_id = self.model.generation_config.eos_token_id
@@ -269,6 +273,7 @@ class Engine:


self.block_num = inference_context.k_cache[0].size(1)
#TODO difference in ModelRunner
# self.model_runner = ModelRunner(self.model, self.device, self.args.use_cuda_graph, page_size = args.page_size, block_num=self.block_num)
#@TODO add config
if config.architectures[0] == "Qwen2MoeForCausalLM" or config.architectures[0] == "Qwen3MoeForCausalLM" or config.architectures[0] == "Glm4MoeForCausalLM" or config.architectures[0] == "SmallThinkerForCausalLM" or config.architectures[0] == "Qwen3NextForCausalLM":
@@ -328,18 +333,26 @@ class Engine:
continue
# print(f"Putting token {q.generated_token} into queue for query id: {q.id}")
try:
if torch.distributed.get_rank() == 0:
if use_torch_npu:
if torch.distributed.get_rank() == 0:
self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
else:
self.gen_queue.put((q.id, q.generated_token if q.decode_done == False else None), timeout=5)
except queue.Full:
pass#print("Queue is full after timeout; unable to put more items.")

if torch.distributed.get_rank() == 0:
if use_torch_npu:
if torch.distributed.get_rank() == 0:
next_batch = self.sched_client.update_last_batch(self.updates)
if next_batch.query_ids == []:
next_batch = None
self.pub_socket.send_pyobj(next_batch)
else:
next_batch = self.sub_socket.recv_pyobj()
else:
next_batch = self.sched_client.update_last_batch(self.updates)
if next_batch.query_ids == []:
next_batch = None
self.pub_socket.send_pyobj(next_batch)
else:
next_batch = self.sub_socket.recv_pyobj()

if next_batch is not None:
self.query_manager.add_query(next_batch)
@@ -372,7 +385,7 @@ def init_distributed(rank: int,
tp_size: int,
master_addr: str = os.getenv("MASTER_ADDR", "127.0.0.1"),
master_port: int = os.getenv("MASTER_PORT", "29500"),
backend: str = "hccl"):
backend: str = "hccl"): #TODO csx: is the distributed setup relevant only to NPU?
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
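For readers following the `init_distributed` changes: the function seeds the usual torch.distributed environment variables and brings up an HCCL process group. A minimal sketch under that assumption (the real implementation may do more):

```python
import os
import torch.distributed as dist

def init_distributed_sketch(rank: int, world_size: int,
                            master_addr: str = os.getenv("MASTER_ADDR", "127.0.0.1"),
                            master_port: str = os.getenv("MASTER_PORT", "29500"),
                            backend: str = "hccl"):
    # "hccl" is the Ascend collective backend; a CUDA setup would use "nccl".
    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = str(master_port)
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
```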
@@ -384,7 +397,8 @@ def init_distributed(rank: int,


def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank=None, world_size=None):
init_distributed(rank, world_size, args.tp, backend="hccl")
if use_torch_npu:
init_distributed(rank, world_size, args.tp, backend="hccl") #TODO same as above
import torch.distributed as dist
engine = Engine(args, token_queue, broadcast_endpoint, kvcache_event)
if args.use_cuda_graph:
@@ -392,8 +406,8 @@ def run_engine(args, token_queue, broadcast_endpoint, event, kvcache_event, rank
engine.model_runner.warmup_npu()
else:
engine.model_runner.warmup()

args.port += torch.distributed.get_rank()
if use_torch_npu:
args.port += torch.distributed.get_rank()
event.set()
engine.loop()
@@ -425,29 +439,39 @@ class BalanceServeInterface(BackendInterfaceBase):
self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, trust_remote_code=True)
self.sched_client = SchedulerClient(args.sched_port)
self.streamer = TextStreamer(self.tokenizer)
world_size = str(os.getenv("WORLD_SIZE", self.args.tp))
if not isinstance(world_size, str):
raise ValueError(f"world_size ({world_size}) must be str")
start_events = []
kvcache_events = []
for rank in range(self.args.tp):
if int(self.args.device[-1]) > 0:
break
if use_torch_npu:
world_size = str(os.getenv("WORLD_SIZE", self.args.tp))
if not isinstance(world_size, str):
raise ValueError(f"world_size ({world_size}) must be str")
start_events = []
kvcache_events = []
for rank in range(self.args.tp):
if int(self.args.device[-1]) > 0:
break

start_event = ctx.Event()
kvcache_event = ctx.Event()

p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,
kvcache_event, rank, world_size))
p.start()
processes.append(p)
start_events.append(start_event)
kvcache_events.append(kvcache_event)

for evt in kvcache_events:
evt.wait()
self._engines = processes
else:
start_event = ctx.Event()
kvcache_event = ctx.Event()

p = ctx.Process(target=run_engine, args=(self.args, self.token_queue, self.broadcast_endpoint, start_event,
kvcache_event, rank, world_size))
kvcache_event))
p.start()
processes.append(p)
start_events.append(start_event)
kvcache_events.append(kvcache_event)

for evt in kvcache_events:
evt.wait()
self._engines = processes

kvcache_event.wait()
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
if use_torch_npu:
args.tp = input_args.tp
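Both branches above follow the same launch recipe: spawn one `run_engine` process per tensor-parallel rank, hand each a pair of Events, then block until every rank reports its KV cache as ready. A condensed sketch of that recipe (argument names assumed from the hunk):

```python
import torch.multiprocessing as mp

def launch_engines(run_engine, args, token_queue, broadcast_endpoint, tp: int, world_size):
    ctx = mp.get_context("spawn")
    processes, start_events, kvcache_events = [], [], []
    for rank in range(tp):
        start_event, kvcache_event = ctx.Event(), ctx.Event()
        p = ctx.Process(target=run_engine,
                        args=(args, token_queue, broadcast_endpoint,
                              start_event, kvcache_event, rank, world_size))
        p.start()
        processes.append(p)
        start_events.append(start_event)
        kvcache_events.append(kvcache_event)
    for evt in kvcache_events:  # wait until every rank has built its KV cache
        evt.wait()
    return processes, start_events
```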
@@ -486,9 +510,11 @@ class BalanceServeInterface(BackendInterfaceBase):
sched_process.wait()
signal.signal(signal.SIGINT, signal_handler)
signal.signal(signal.SIGTERM, signal_handler)

for evt in start_events:
evt.wait()
if use_torch_npu:
for evt in start_events:
evt.wait()
else:
start_event.wait()

def get_params(self, temperature: Optional[float] = None, top_p: Optional[float] = None,
max_tokens: Optional[float] = None, max_completion_tokens: Optional[float] = None) -> tuple[float, float]:
@@ -575,7 +601,7 @@ class BalanceServeInterface(BackendInterfaceBase):
raise ValueError("local_messages should be List or str")
if Config().user_force_think:
token_thinks = torch.tensor([self.tokenizer.encode("<think>\n",add_special_tokens=False)],device=input_ids.device)
if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]):
if not torch.equal(input_ids[0, -token_thinks.shape[-1]:], token_thinks[-1]): #TODO newly added line; check whether it affects the GPU path
input_ids = torch.cat(
[input_ids, token_thinks], dim=1
)
@@ -623,7 +649,7 @@ class BalanceServeInterface(BackendInterfaceBase):
profiler.pause_timer("decode")
report_last_time_performance(profiler)
yield self.streamer.end(), None
if profiler.get_counter('decode') >= self.args.max_new_tokens - 1:
if profiler.get_counter('decode') >= self.args.max_new_tokens - 1: #TODO max_new_tokens is passed in differently
yield "", "length"
else:
yield "", "stop"
@@ -15,13 +15,13 @@ import os
try:
import torch_npu
use_npu = torch.npu.is_available()
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel
except:
use_npu = False
from torch import nn
from ktransformers.server.config.log import logger
from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.ascend.ascend_utils import get_absort_weight, setup_model_parallel
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device, get_all_used_cuda_device
@@ -4,6 +4,11 @@ LastEditors: Xie Weiyu ervinxie@qq.com
LastEditTime: 2024-11-26 08:12:49
'''
import torch
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
except:
use_torch_npu = False
from ktransformers.server.balance_serve.settings import sched_ext
from ktransformers.server.balance_serve.inference.query_manager import QueryManager, QueryInfo
from typing import Union
@@ -195,8 +200,9 @@ class ForwardMiniBatchSplit:

def __init__(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo],
prefill_s: list[int] = None, prefill_l: list[int] = None,
device=torch.device('npu'), page_size=256, max_page_num=64,
device=None, page_size=256, max_page_num=64,
decode_padding_len: int = 1):
device = torch.device('npu')
batch_decode = len(decode_querys_info)
# batch_prefill = len(prefill_querys_info)
# update valid prefill batch
@@ -295,7 +301,8 @@ class ForwardMiniBatchSplit:

self.bsz_tensor = torch.tensor([self.batch_size], device=device, dtype=torch.int32)

def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, decode_padding_len=1, device = torch.device('npu'), page_size = 256, max_page_num=64):
def fill(self, prefill_querys_info: list[QueryInfo], decode_querys_info: list[QueryInfo], prefill_s: list[int] = None, prefill_l: list[int] = None, decode_padding_len=1, device = None, page_size = 256, max_page_num=64):
device = torch.device('npu')

page_size = 128
@@ -439,9 +446,10 @@ class ForwardBatchInput:
query_manager.query_map[decode_batch_idx].decode_start_time =time.time()
decode_querys_info.append(query_manager.query_map[decode_batch_idx])


minibatch = ForwardMiniBatchSplit(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)

if use_torch_npu:
minibatch = ForwardMiniBatchSplit(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)
else:
minibatch = ForwardMiniBatchCombine(prefill_querys_info, decode_querys_info, prefill_s, prefill_l, device = query_manager.device, page_size = query_manager.page_size)
self.minibatch = minibatch

@classmethod
@@ -487,11 +495,12 @@ class ForwardBatchInput:

if prefill_query_length*Config().max_prefill_batch_size + len(decode_querys_info) < cuda_lens:
decode_querys_info.append(query_info)

instance.minibatch = ForwardMiniBatchSplit(prefill_query_info, decode_querys_info, [0, 0],
[prefill_active_length for _ in range(Config().max_prefill_batch_size)],
device, page_size, decode_padding_len=decode_query_length)

if use_torch_npu:
instance.minibatch = ForwardMiniBatchSplit(prefill_query_info, decode_querys_info, [0, 0],
[prefill_active_length for _ in range(Config().max_prefill_batch_size)],
device, page_size, decode_padding_len=decode_query_length)
else:
instance.minibatch = ForwardMiniBatchCombine(prefill_query_info, decode_querys_info, [0, 0], [prefill_active_length for _ in range(Config().max_prefill_batch_size)], device, page_size)

return instance
@@ -7,7 +7,6 @@ import os.path
import threading

import torch
import torch_npu
from torch import nn
import queue
import signal
@@ -29,7 +28,6 @@ import tempfile
from ktransformers.server.balance_serve.inference.forward_batch import (
ForwardBatchInput, ForwardBatchOutput, ForwardMiniBatchCombine, ForwardMiniBatchSplit)
from ktransformers.util import utils
from ktransformers.models.custom_cache import KVC2StaticCache

from ktransformers.server.config.config import Config
from ktransformers.models.custom_modeling_deepseek_v3 import KDeepseekV3ForCausalLM
@@ -38,7 +36,6 @@ from ktransformers.models.custom_modeling_qwen2_moe import KQwen2MoeForCausalLM
from ktransformers.models.custom_modeling_qwen3_moe import KQwen3MoeForCausalLM
from ktransformers.models.custom_modeling_smallthinker import KSmallThinkerForCausalLM
from ktransformers.models.custom_modeling_glm4_moe import KGlm4MoeForCausalLM
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
from ktransformers.models.custom_modeling_qwen3_next import KQwen3NextForCausalLM
from ktransformers.server.balance_serve.inference.query_manager import QueryManager
from ktransformers.server.balance_serve.settings import sched_ext
@@ -46,6 +43,8 @@ from ktransformers.server.balance_serve.settings import sched_ext
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
from ktransformers.models.ascend.custom_ascend_modeling_deepseek_v3 import KNPUDeepseekV3ForCausalLM
from ktransformers.models.custom_cache import KVC2StaticCache
except:
use_torch_npu = False
@@ -68,11 +67,14 @@ def generate_cuda_graphs(chunk_size: int) -> list:
return deduplicate_and_sort(base_list + multiples)
class ModelRunner:
"""A CudaGraphRunner runs the forward pass of a model with CUDA graph and torch.compile."""

model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM | KNPUDeepseekV3ForCausalLM
if not use_torch_npu:
model: KDeepseekV3ForCausalLM | KQwen2MoeForCausalLM | KQwen3MoeForCausalLM | KSmallThinkerForCausalLM | KGlm4MoeForCausalLM | KQwen3NextForCausalLM
else:
model: KNPUDeepseekV3ForCausalLM
cache: KVC2StaticCache #TODO only used by the NPU-adapted code; worked around here
input: ForwardBatchInput | list[ForwardBatchInput]
output: ForwardBatchOutput
cache: KVC2StaticCache


def __init__(self, model = None, cache = None, device = None, use_cuda_graph = False, max_decode_batch_size = 1, max_chunk_size = 4096, num_mini_batches: int = 1, page_size = 256, block_num = 8):
@@ -103,6 +105,7 @@ class ModelRunner:
self.output = None
self.graph_memory_pool = None
self.cache = cache
#TODO the line self.cuda_graphs = generate_cuda_graphs(Config().chunk_size) was removed here; why, and won't that affect the GPU path below?
self.use_cuda_graph = use_cuda_graph
self.debug = False
@@ -295,7 +298,7 @@ class ModelRunner:
head_dim_kpe=self.model.config.qk_rope_head_dim, page_size=self.cache.page_size, causal=True,
sm_scale=self.model.model.layers[0].self_attn.softmax_scale, q_data_type=torch.bfloat16, kv_data_type=torch.bfloat16)
self.start_model_event.record(self.stream)
page_idx, page_offset = self.cache.get_page_table(self.input[cuda_graph_idx].minibatch, self.bsz_tensor_buf)
page_idx, page_offset = self.cache.get_page_table(self.input[cuda_graph_idx].minibatch, self.bsz_tensor_buf) #TODO csx minibatch

self.page_idx_buf[cuda_graph_idx][:num_tokens].copy_(page_idx[:num_tokens])
self.page_offset_buf[cuda_graph_idx][:num_tokens].copy_(page_offset[:num_tokens])
@@ -4,7 +4,6 @@ import re
from uuid import uuid4

import torch
import torch_npu
import torch.distributed
from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
@@ -128,20 +127,21 @@ def verify_arg(args):
def main():
try:
import torch_npu
if torch.npu.is_available():
torch.npu.config.allow_internal_format = True
use_npu = torch.npu.is_available()
torch.npu.config.allow_internal_format = True
except:
pass
use_npu = False

cfg = Config()

arg_parser = ArgumentParser(cfg)

args = arg_parser.parse_args()
verify_arg(args)
if use_npu:
verify_arg(args)

rank_id = int(os.environ["RANK"])
args.device = args.device[:-1] + str(rank_id)
rank_id = int(os.environ["RANK"])
args.device = args.device[:-1] + str(rank_id)
create_interface(config=cfg, default_args=cfg, input_args=args)

tp_size = args.tp
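The `main()` hunk above derives each process's device by splicing the RANK environment variable into the trailing digit of the `--device` string. A small illustration of that splice (helper name is illustrative; note it only works for single-digit device indices):

```python
import os

def select_rank_device(device: str) -> str:
    # e.g. device="npu:0" and RANK=2 -> "npu:2"
    rank_id = int(os.environ.get("RANK", "0"))
    return device[:-1] + str(rank_id)
```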
@@ -702,6 +702,7 @@ def translate_name_to_gguf(name):
name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")
name = name.replace(".mlp.experts", "")
#TODO consider the impact of this change
name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
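The hunk above swaps a blanket `.mlp.experts` strip for explicit per-tensor replacements. A tiny illustration of how these chained `str.replace` calls map an HF module name onto its GGUF tensor name (only the replacements shown above are included):

```python
def translate_name_sketch(name: str) -> str:
    # Chained replacements, as in translate_name_to_gguf.
    name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
    name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
    name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
    name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
    return name

# "model.layers.0.mlp.experts.ffn_up_exps" -> "model.layers.0.ffn_up_exps"
print(translate_name_sketch("model.layers.0.mlp.experts.ffn_up_exps"))
```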
@@ -24,8 +24,6 @@ from ktransformers.util.custom_gguf import *
from safetensors.torch import save_file
from abc import ABC, abstractmethod
from typing import Dict, Any, Optional, Union
from safetensors import safe_open
from safetensors.torch import save_file

class ModelLoader(ABC):
"""
@@ -11,7 +11,6 @@ import sys
import threading

import torch
import torch_npu
import torch.distributed as dist
from torch import nn
import itertools
@@ -27,7 +26,7 @@ from transformers import (
EpsilonLogitsWarper,
EtaLogitsWarper,
)
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size

from ktransformers.util.custom_loader import ModelLoaderFactory, ModelLoader, SafeTensorLoader, translate_name_to_gguf
from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
@@ -50,6 +49,7 @@ _SPECULATE_STEP = 1
try:
import torch_npu
use_torch_npu = torch_npu.npu.is_available()
from ktransformers.util.ascend.ascend_utils import get_tensor_parallel_size
except:
use_torch_npu = False
@@ -177,8 +177,8 @@ def get_current_device():
return f"cuda:{torch.npu.current_device()}"

def get_compute_capability(device:torch.device = None):
if use_torch_npu:
return 0
# if use_torch_npu:
# return 0
if torch.cuda.is_available():
if device is None:
num_gpus = torch.cuda.device_count()
@@ -189,6 +189,8 @@ def get_compute_capability(device:torch.device = None):
return min_compute_capability_major
else:
return torch.cuda.get_device_properties(device)
else:
return 0 #TODO why not write it this way?

def set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
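The last two lines above are the start of `set_module`, which the inject code uses to swap a replacement operator into the model. A minimal sketch of that helper under the usual dotted-path convention (sketch only; the repository's version may handle more cases):

```python
from torch import nn

def set_module_sketch(model: nn.Module, submodule_key: str, module: nn.Module) -> None:
    tokens = submodule_key.split('.')
    parent = model
    for tok in tokens[:-1]:
        # nn.Module exposes registered children (including "0", "1", ...) as attributes.
        parent = getattr(parent, tok)
    setattr(parent, tokens[-1], module)
```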