'''
Description :
Author : Boxin Zhang
Version : 0.1.0
'''
# Adapted from
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/cache_utils.py
# Copyright 2018- The Hugging Face team. All rights reserved.
# Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
import torch
import torch.nn as nn
import transformers
from transformers import Cache, PretrainedConfig
from typing import List, Optional, Dict, Any, Tuple

try:
    import torch_npu
    from ktransformers.util import utils
    from ktransformers.server.balance_serve.inference.forward_batch import ForwardMiniBatchCombine, ForwardMiniBatchSplit

    use_torch_npu = torch_npu.npu.is_available()
except:
    use_torch_npu = False
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from ktransformers.server.balance_serve.settings import sched_ext


class StaticCache(transformers.StaticCache):
    """
    Static Cache class to be used with `torch.compile(model)`.

    Parameters:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The maximum batch size with which the model will be used.
        max_cache_len (`int`):
            The maximum sequence length with which the model will be used.
        device (`torch.device` or `dict`):
            The device on which the cache should be initialized. Should be the same as the layer.
            If a `dict`, it should contain the `device` key with the device name as the value.
        dtype (*optional*, defaults to `torch.float32`):
            The default `dtype` to use when initializing the layer.
    """

    def __init__(self, config: PretrainedConfig, max_batch_size: int, max_cache_len: int, device: torch.device | dict, dtype=None) -> None:
        Cache.__init__(self, layer_class_to_replicate=LlamaDecoderLayer)
        self._max_batch_size = max_batch_size

        if use_torch_npu:
            self.position = [0]

        self._max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
        # Some models define a custom `head_dim` != config.hidden_size // config.num_attention_heads
        if config.architectures[0] == "DeepseekV3ForCausalLM":
            self.head_dim = config.qk_rope_head_dim
        else:
            self.head_dim = (
                config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
            )

        self.dtype = dtype if dtype is not None else torch.float32
        self.num_key_value_heads = (
            config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
        )

        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        cache_shape = (max_batch_size, self.num_key_value_heads, self._max_cache_len, self.head_dim)
        if config.architectures[0] == "DeepseekV2ForCausalLM" or config.architectures[0] == "DeepseekV3ForCausalLM":
            # TODO: for deepseek, cache_shape is different whether using Absorbed MLA, check it automatically
            if use_torch_npu:
                self.page_size = 128
                self.page_size_tensor = torch.tensor(
                    self.page_size,
                    dtype=torch.int32,
                ).npu()
                self.max_pages_per_batch = (self._max_cache_len + self.page_size - 1) // self.page_size
                self.max_pages = (self._max_cache_len + self.page_size - 1) // self.page_size * self._max_batch_size
            else:
                self.page_size = 64
                self.max_pages = (self._max_cache_len + self.page_size - 1) // self.page_size
            latent_shape = (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
            self.kv_lora_rank = config.kv_lora_rank
            self.qk_rope_head_dim = config.qk_rope_head_dim
            # TODO: support real page table
            self.page_table_map = dict()
            self.page_table_list = []
            for idx in range(config.num_hidden_layers):
                if isinstance(device, dict):
                    target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
                else:
                    target_device = device

                if target_device not in self.page_table_map:
                    if use_torch_npu:
                        page_table = torch.zeros((max_batch_size, self.max_pages_per_batch), dtype=torch.int32, device=target_device)
                        for seq_id in range(max_batch_size):
                            page_table[seq_id, :] = torch.arange(seq_id * self.max_pages_per_batch, seq_id * self.max_pages_per_batch + self.max_pages_per_batch, dtype=torch.int32, device=target_device)
                    else:
                        page_table = torch.zeros((max_batch_size, self.max_pages), dtype=torch.int32, device=target_device)
                        for seq_id in range(max_batch_size):
                            page_table[seq_id, :] = torch.arange(seq_id * self.max_pages, seq_id * self.max_pages + self.max_pages, dtype=torch.int32, device=target_device)
                    self.page_table_map[target_device] = page_table

                self.page_table_list.append(self.page_table_map[target_device])

            self.is_MLA = True
            self.is_page = True
        else:
            key_shape = cache_shape
            value_shape = cache_shape
            self.is_MLA = False

        self.past_tokens = []
        self.num_hidden_layers = config.num_hidden_layers
        for idx in range(self.num_hidden_layers):
            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
            # breaks when updating the cache.
            if isinstance(device, dict):
                target_device = device[f"blk.{idx}.self_attn"]["generate_device"]
            else:
                target_device = device

            if self.is_MLA:
                new_layer_key_cache = torch.zeros(latent_shape, dtype=self.dtype, device=target_device)
                new_layer_value_cache = None
                torch._dynamo.mark_static_address(new_layer_key_cache)
            else:
                new_layer_key_cache = torch.zeros(key_shape, dtype=self.dtype, device=target_device)
                new_layer_value_cache = torch.zeros(value_shape, dtype=self.dtype, device=target_device)
                torch._dynamo.mark_static_address(new_layer_key_cache)
                torch._dynamo.mark_static_address(new_layer_value_cache)

            self.key_cache.append(new_layer_key_cache)
            self.value_cache.append(new_layer_value_cache)
            self.past_tokens.append(0)

    @property
    def max_batch_size(self):
        return self._max_batch_size

    @property
    def max_cache_len(self):
        return self._max_cache_len

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The `StaticCache` needs the `cache_position` input
                to know where to write in the cache.

        Return:
            A tuple containing the updated key and value states (for MLA models, the combined latent cache and
            this layer's page table).
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out = self.key_cache[layer_idx]
        v_out = self.value_cache[layer_idx]
        self.past_tokens[layer_idx] += cache_position.size(0)
        # print(cache_position)
        if self.is_MLA:
            if use_torch_npu:
                page_idx = cache_position // self.page_size_tensor
                page_offset = cache_position % self.page_size_tensor

                page_idx = page_idx.unsqueeze(0).expand(self.max_batch_size, -1)
                page_offset = page_offset.unsqueeze(0).expand(self.max_batch_size, -1)

                page_idx_offset = torch.arange(self.max_batch_size, device=page_idx.device) * self.max_pages_per_batch
                page_idx = page_idx + page_idx_offset.unsqueeze(1)

                combined = torch.cat([key_states, value_states], dim=-1)
                combined = combined.contiguous()
                # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
                k_out[page_idx, page_offset] = combined
            else:
                page_idx = cache_position // self.page_size
                page_offset = cache_position % self.page_size
                # key shape (self.max_pages, self.page_size, 1, config.kv_lora_rank + config.qk_rope_head_dim)
                k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states
                k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states
            return k_out, self.page_table_list[layer_idx]
        else:
            k_out[:, :, cache_position] = key_states
            v_out[:, :, cache_position] = value_states
            return k_out, v_out
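
    # Decode-step sketch (illustrative only; the position value and `key_states`/`value_states`
    # below are assumptions): on the MLA path, `cache_position` is mapped onto pages as
    #     page_idx    = cache_position // page_size
    #     page_offset = cache_position %  page_size
    # so a single new token at absolute position 130 with page_size=64 lands in page 2, slot 2.
    #
    #     cache_position = torch.tensor([130], device=key_states.device)
    #     k_cache, page_table = cache.update(key_states, value_states, layer_idx=0,
    #                                        cache_kwargs={"cache_position": cache_position})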

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        # TODO: deprecate this function in favor of `cache_position`
        return self.past_tokens[layer_idx]

    def change_seq_length(self, bias: Optional[int] = 0) -> None:
        """Shifts the recorded sequence length of every layer by `bias` tokens."""
        for layer_idx in range(self.num_hidden_layers):
            self.past_tokens[layer_idx] += bias

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def get_usable_length(self, kv_seq_len, layer_idx: Optional[int] = 0) -> int:
        return 0

    def reset(self):
        """Resets the cache values while preserving the objects"""
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            self.key_cache[layer_idx].zero_()
            if self.value_cache[layer_idx] is not None:
                self.value_cache[layer_idx].zero_()
            self.past_tokens[layer_idx] = 0

        if use_torch_npu:
            self.position = [0]

    def remove_suffix(self, start_pos):
        for layer_idx in range(len(self.key_cache)):
            # In-place ops prevent breaking the static address
            if self.is_MLA:
                k_cache = self.key_cache[layer_idx]
                k_cache.view(-1, k_cache.shape[-1])[start_pos:].zero_()
            else:
                self.key_cache[layer_idx][..., start_pos:, :].zero_()
                self.value_cache[layer_idx][..., start_pos:, :].zero_()
            self.past_tokens[layer_idx] = start_pos

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length the cache can hold."""
        return self.max_cache_len


class KVC2StaticCache:
    """
    Static Cache class connected to KVC2.
    Note: page_idx & page_offset must be taken from the forward batch; only the KV block tensors are held here.
    """
    def __init__(self, config: PretrainedConfig, max_batch_size, page_size: int = 256, dtype=torch.bfloat16, device=None) -> None:
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.device = torch.device("npu:0")
        self.kv_lora_rank = config.kv_lora_rank
        self.max_batch_size = max_batch_size
        self.page_size = page_size
        self.k_caches = []
        self.v_caches = []

        self.num_hidden_layers = config.num_hidden_layers

        self.is_MLA = True if config.architectures[0] in ["DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"] else False
        # kv cache stored in kvc2
        # self.past_tokens = []

    def load(self, inference_context):
        # assert self.is_MLA and len(inference_context.k_cache) == 1, "currently only support MLA and Cache Pool TP=1"
        from ktransformers.util.utils import get_current_device
        for i in range(self.config.num_hidden_layers):
            new_layer_key_cache = inference_context.k_cache[int(torch.distributed.get_rank())][i].to(get_current_device())
            torch._dynamo.mark_static_address(new_layer_key_cache)

            self.k_caches.append(
                new_layer_key_cache  # [TP_idx, layer_idx, page_idx, page_size, kv_head_num, kv_head_dim]
            )

            self.v_caches.append(None)
        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]  # page_len * page_size

    def update(
        self,
        combined: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `combined` KV states for the layer `layer_idx`.

        Parameters:
            combined (`torch.Tensor`):
                The new compressed KV states (kv_lora and rope parts concatenated) to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Must contain `page_idx` (`torch.Tensor`), `page_offset` (`torch.Tensor`) and `cache_position` (`torch.Tensor`).

        Return:
            A tuple containing the updated key cache and the page indices.
        """
        page_idx, page_offset = cache_kwargs.get("page_idx"), cache_kwargs.get("page_offset")
        if page_idx is None or page_offset is None:
            raise ValueError('[ERROR] block info: page_idx & page_offset missing!')

        k_out = self.k_caches[layer_idx]
        assert self.is_MLA, "currently only support DeepSeekV3 on NPU balance server"

        if page_idx.dim() == 1:
            page_idx_tmp = page_idx.unsqueeze(0)
            page_offset_tmp = page_offset.unsqueeze(0)
        else:
            page_idx_tmp = page_idx
            page_offset_tmp = page_offset

        k_out[page_idx_tmp, page_offset_tmp] = combined
        return k_out, page_idx

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states that were seen by the model."""
        raise ValueError('kvc2 cache pool no longer holds seq_length info; refer to the forward batch')

    def get_usable_length(self, kv_seq_len, layer_idx: Optional[int] = 0) -> int:
        return 0

    def change_seq_length(self, bias: Optional[int] = 0) -> int:
        """Shifts the recorded sequence length by `bias` tokens."""
        raise ValueError('kvc2 cache pool no longer holds seq_length info; refer to the forward batch')

    def get_max_length(self) -> Optional[int]:
        """Returns the maximum sequence length of the cached states."""
        return self.max_cache_len

    def reset(self, inference_context):
        assert self.is_MLA and len(inference_context.k_cache) == 1, "currently only support MLA and Cache Pool TP=1"
        self.k_caches = []
        self.v_caches = []
        for i in range(self.config.num_hidden_layers):
            self.k_caches.append(
                inference_context.k_cache[0][i]
            )
            self.v_caches.append(None)
        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]  # page_len * page_size

    def get_page_table(self, mini_batch, bsz_tensors: torch.tensor = None, is_prefill=True):
        if is_prefill:
            # TODO: add padding support
            q_lens = [mini_batch.p_q_len[idx] for idx in range(mini_batch.prefill_batch)]
            page_local_idx = -1 * torch.ones(mini_batch.prefill_batch, max(q_lens),
                                             dtype=mini_batch.p_position_ids.dtype, device=mini_batch.p_position_ids.device)
            page_offset = -1 * torch.ones_like(page_local_idx)
            # convert merged positions into batched layout
            start_ids = 0
            for i in range(mini_batch.prefill_batch):
                page_offset[i, 0:q_lens[i]] = mini_batch.p_position_ids[start_ids:start_ids+q_lens[i]] % self.page_size
                page_local_idx[i, 0:q_lens[i]] = mini_batch.p_position_ids[start_ids:start_ids+q_lens[i]] // self.page_size
                for j in range(q_lens[i]):
                    # map the local page idx to the global page idx via the block table, same as in the decode branch
                    page_local_idx[i, j] = mini_batch.p_block_tables[i, page_local_idx[i, j]]
                start_ids += q_lens[i]
            page_idx = page_local_idx
            # only padding can leave -1 values in page_local_idx/page_offset;
            # the following check can be used to verify:
            # indices = torch.where(page_offset == -1)
            # assert not indices[0].numel() > 0, 'there are still un-computed page_idx values'
        else:
            page_local_idx = mini_batch.d_position_ids // self.page_size

            page_offset = mini_batch.d_position_ids % self.page_size

            for i in range(mini_batch.decode_batch):
                page_local_idx[i] = mini_batch.d_block_tables[i, page_local_idx[i]]

            page_idx = page_local_idx

        return page_idx, page_offset
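
    # Worked example (assumed numbers, for illustration only): with page_size=256 and a single
    # prefill request whose p_position_ids are [0..511], the local pages come out as
    # [0]*256 + [1]*256 and the offsets as [0..255, 0..255]; p_block_tables[i] then maps local
    # pages 0/1 to the global page ids kvc2 actually allocated, which is what `update` indexes with.
    #
    #     page_idx, page_offset = kv_cache.get_page_table(mini_batch, is_prefill=True)
    #     k_out, _ = kv_cache.update(combined, layer_idx,
    #                                {"page_idx": page_idx, "page_offset": page_offset})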


class KDeepSeekV3Cache(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        page_size: int = 256,
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.device = device
        self.kv_lora_rank = config.kv_lora_rank
        self.page_size = page_size
        self.k_caches = []
        self.v_caches = []

    def load(self, inference_context: "sched_ext.InferenceContext"):
        for i in range(self.config.num_hidden_layers):
            self.k_caches.append(
                inference_context.k_cache[0][i]
            )
        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        page_idx: torch.Tensor,
        page_offset: torch.Tensor,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
        It is VERY important to index using a tensor, otherwise you introduce a copy to the device.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            page_idx (`torch.Tensor`):
                Global page indices for each new token.
            page_offset (`torch.Tensor`):
                Offset of each new token inside its page.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass.

        Return:
            The updated key cache for this layer.
        """
        k_out = self.k_caches[layer_idx]

        k_out[page_idx, page_offset, :, :self.kv_lora_rank] = key_states.reshape(-1, *key_states.shape[2:])
        k_out[page_idx, page_offset, :, self.kv_lora_rank:] = value_states.reshape(-1, *value_states.shape[2:])
        return k_out

    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):
        page_offset = cache_position % self.page_size
        page_idx_local = cache_position // self.page_size
        query_ids = torch.zeros_like(cache_position)
        for i in range(len(q_indptr) - 1):
            start_idx = q_indptr[i]
            end_idx = q_indptr[i + 1]
            query_ids[start_idx:end_idx] = i
        page_idx = torch.zeros_like(page_idx_local)
        for i in range(bsz_tensors[0]):
            query_id = query_ids[i]
            local_block = page_idx_local[i]
            start_block = kv_indptr[query_id]
            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
                page_idx[i] = kv_indices[start_block + local_block]

        return page_idx, page_offset
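
    # Sketch of the indptr layout this method expects (values below are assumptions for
    # illustration): q_indptr/kv_indptr are CSR-style per-request offsets, and kv_indices holds
    # the global page ids of each request's blocks. For two requests with 3 and 2 query tokens:
    #
    #     q_indptr   = torch.tensor([0, 3, 5])   # token i belongs to request query_ids[i]
    #     kv_indptr  = torch.tensor([0, 2, 3])   # request 0 owns 2 pages, request 1 owns 1
    #     kv_indices = torch.tensor([7, 9, 4])   # global page ids handed out by the scheduler
    #
    # so a token of request 0 whose local block is 1 resolves to global page
    # kv_indices[kv_indptr[0] + 1] == 9.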


class KGQACache(nn.Module):
    def __init__(
        self,
        config: PretrainedConfig,
        page_size: int = 256,
        dtype=torch.bfloat16,
        device=torch.device("cuda:0"),
    ):
        super().__init__()
        self.config = config
        self.dtype = dtype
        self.device = device
        self.page_size = page_size
        self.k_caches = []
        self.v_caches = []

    def load(self, inference_context: "sched_ext.InferenceContext"):
        print(self.config.num_hidden_layers)
        for i in range(self.config.num_hidden_layers):
            self.k_caches.append(
                inference_context.k_cache[0][i]
            )
            self.v_caches.append(
                inference_context.v_cache[0][i]
            )

        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]

    def get_page_table(self, cache_position: torch.Tensor, q_indptr: torch.Tensor, kv_indptr: torch.Tensor, kv_indices: torch.Tensor, bsz_tensors: torch.tensor):
        page_offset = cache_position % self.page_size
        page_idx_local = cache_position // self.page_size
        query_ids = torch.zeros_like(cache_position)
        for i in range(len(q_indptr) - 1):
            start_idx = q_indptr[i]
            end_idx = q_indptr[i + 1]
            query_ids[start_idx:end_idx] = i
        page_idx = torch.zeros_like(page_idx_local)
        for i in range(bsz_tensors[0]):
            query_id = query_ids[i]
            local_block = page_idx_local[i]
            start_block = kv_indptr[query_id]
            if local_block < kv_indptr[query_id + 1] - kv_indptr[query_id]:
                page_idx[i] = kv_indices[start_block + local_block]

        return page_idx, page_offset

    def get_k_cache(self, layer_idx):
        return self.k_caches[layer_idx]

    def get_v_cache(self, layer_idx):
        return self.v_caches[layer_idx]


class KVC2Qwen3Cache(nn.Module):

    def __init__(self, config, max_batch_size, page_size=256,
                 dtype=torch.bfloat16, device=None):
        super().__init__()
        self.config = config
        self.max_batch_size = max_batch_size
        self.page_size = page_size
        self.dtype = dtype
        self.device = device if device else torch.device("npu:0")

        self.num_layers = config.num_hidden_layers
        self.num_kv_heads = config.num_key_value_heads
        self.head_dim = config.head_dim

        self.k_caches = []
        self.v_caches = []

    # ------------------------- bind to the underlying kvc2 pool -------------------------
    def load(self, inference_context):
        from ktransformers.util.utils import get_current_device
        dev = get_current_device()

        self.k_caches = []
        self.v_caches = []

        rank = (
            torch.distributed.get_rank()
            if (torch.distributed.is_available() and torch.distributed.is_initialized())
            else 0
        )

        for i in range(self.num_layers):
            k_buf = inference_context.k_cache[rank][i].to(dev).to(self.dtype)
            v_buf = inference_context.v_cache[rank][i].to(dev).to(self.dtype)

            torch._dynamo.mark_static_address(k_buf)
            torch._dynamo.mark_static_address(v_buf)

            self.k_caches.append(k_buf)
            self.v_caches.append(v_buf)

        # num_pages * page_size
        self.max_cache_len = self.k_caches[0].shape[0] * self.k_caches[0].shape[1]

    # ------------------------- write KV -------------------------
    @torch.no_grad()
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ):
        if cache_kwargs is None:
            raise ValueError("[KVC2Qwen3Cache] cache_kwargs must contain page_idx & page_offset")

        page_idx: Optional[torch.Tensor] = cache_kwargs.get("page_idx", None)
        page_offset: Optional[torch.Tensor] = cache_kwargs.get("page_offset", None)

        if page_idx is None or page_offset is None:
            raise ValueError("[KVC2Qwen3Cache] page_idx & page_offset are required in cache_kwargs")

        k_out = self.k_caches[layer_idx]
        v_out = self.v_caches[layer_idx]

        # -------- 1) fix dim order: [B, KvH, Q, D] -> [B, Q, KvH, D] --------
        if key_states.dim() == 4 and key_states.shape[1] == self.num_kv_heads:
            key_states = key_states.transpose(1, 2).contiguous()
            value_states = value_states.transpose(1, 2).contiguous()

        if key_states.shape != value_states.shape:
            raise ValueError(
                f"[KVC2Qwen3Cache] key_states.shape {key_states.shape} "
                f"!= value_states.shape {value_states.shape}"
            )

        if key_states.dim() != 4:
            raise ValueError(
                f"[KVC2Qwen3Cache] expect key_states dim=4, got {key_states.dim()} "
                f"(shape={key_states.shape})"
            )

        bsz, q_len, kv_heads, head_dim = key_states.shape

        if kv_heads != self.num_kv_heads or head_dim != self.head_dim:
            raise ValueError(
                f"[KVC2Qwen3Cache] KV shape mismatch: "
                f"got num_kv_heads={kv_heads}, head_dim={head_dim}, "
                f"expected num_kv_heads={self.num_kv_heads}, head_dim={self.head_dim}"
            )

        # -------- 2) flatten page_idx / page_offset to 1-D --------
        page_idx = page_idx.reshape(-1)
        page_offset = page_offset.reshape(-1)

        # -------- 3) flatten KV and cast dtype to match the cache --------
        val_dtype = k_out.dtype
        flat_k = key_states.to(val_dtype).reshape(-1, kv_heads, head_dim)
        flat_v = value_states.to(val_dtype).reshape(-1, kv_heads, head_dim)

        # -------- 4) write K / V into the paged cache --------
        # k_out / v_out: [num_pages, page_size, num_kv_heads, head_dim]
        k_out[page_idx, page_offset] = flat_k
        v_out[page_idx, page_offset] = flat_v
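
    # Shape sketch (illustrative; the batch/length numbers below are assumptions): attention
    # modules usually hand over KV as [B, num_kv_heads, q_len, head_dim]; step 1 above transposes
    # that to [B, q_len, num_kv_heads, head_dim], so flattening yields one [num_kv_heads, head_dim]
    # slab per token, which is exactly what one (page_idx, page_offset) slot of the paged cache holds.
    #
    #     key_states.shape == (2, 8, 16, 128)   # B=2, kv_heads=8, q_len=16, head_dim=128
    #     flat_k.shape     == (32, 8, 128)      # B*q_len tokens after transpose + reshape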

    # ------------------------- get K/V -------------------------
    def get_k_cache(self, layer_idx):
        return self.k_caches[layer_idx]

    def get_v_cache(self, layer_idx):
        return self.v_caches[layer_idx]

    # ------------------------- page table computation -------------------------
    def get_page_table(
        self,
        mini_batch,
        bsz_tensors: torch.Tensor = None,
        is_prefill: bool = True,
    ):
        if is_prefill:
            # prefill: merged positions => batched (B, T_chunk)
            q_lens = [int(mini_batch.p_q_len[idx]) for idx in range(mini_batch.prefill_batch)]
            if len(q_lens) == 0:
                return None, None

            max_q_len = max(q_lens)

            page_local_idx = -1 * torch.ones(
                mini_batch.prefill_batch,
                max_q_len,
                dtype=mini_batch.p_position_ids.dtype,
                device=mini_batch.p_position_ids.device,
            )
            page_offset = -1 * torch.ones_like(page_local_idx)

            start_ids = 0
            for i in range(mini_batch.prefill_batch):
                cur_len = q_lens[i]
                pos = mini_batch.p_position_ids[start_ids:start_ids + cur_len]  # global pos of this chunk

                # local block + offset by page_size
                page_offset[i, 0:cur_len] = pos % self.page_size
                page_local_idx[i, 0:cur_len] = pos // self.page_size

                # local block -> global page id via block_tables
                for j in range(cur_len):
                    blk = page_local_idx[i, j]
                    page_local_idx[i, j] = mini_batch.p_block_tables[i, blk]

                start_ids += cur_len

            page_idx = page_local_idx
        else:
            # decode: decode_batch = batch size of the current step, usually one token per sample
            page_local_idx = mini_batch.d_position_ids // self.page_size
            page_offset = mini_batch.d_position_ids % self.page_size

            for i in range(mini_batch.decode_batch):
                blk = page_local_idx[i]
                page_local_idx[i] = mini_batch.d_block_tables[i, blk]

            page_idx = page_local_idx

        return page_idx, page_offset
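
    # End-to-end prefill sketch (illustrative only; `qwen3_cache`, `mini_batch`, `k`, `v` are
    # assumed to come from the balance-serve forward batch and the attention layer respectively):
    #
    #     page_idx, page_offset = qwen3_cache.get_page_table(mini_batch, is_prefill=True)
    #     qwen3_cache.update(k, v, layer_idx,
    #                        cache_kwargs={"page_idx": page_idx, "page_offset": page_offset})
    #     k_cache = qwen3_cache.get_k_cache(layer_idx)   # [num_pages, page_size, kv_heads, head_dim]
    #     v_cache = qwen3_cache.get_v_cache(layer_idx)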