mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
Non-blocking host-device copies in forward pass
This commit is contained in:
@@ -15,6 +15,7 @@ from exllamav2.architecture import RopeStyle
|
||||
import math
|
||||
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
|
||||
import torch.nn.functional as F
|
||||
# from line_profiler import profile
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -255,14 +256,14 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
self.block_index = safe_move_tensor(self.block_index, device)
|
||||
return self.block_index
|
||||
|
||||
def get_cache_seqlens(self, device) -> torch.Tensor:
|
||||
if self.cache_seqlens.device != device:
|
||||
self.cache_seqlens = safe_move_tensor(self.cache_seqlens, device)
|
||||
def get_cache_seqlens(self, device_idx: int) -> torch.Tensor:
|
||||
if self.cache_seqlens.device.index != device_idx:
|
||||
self.cache_seqlens = safe_move_tensor(self.cache_seqlens, device_idx, non_blocking = True)
|
||||
return self.cache_seqlens
|
||||
|
||||
def get_cache_seqlens_after(self, device) -> torch.Tensor:
|
||||
if self.cache_seqlens_after.device != device:
|
||||
self.cache_seqlens_after = safe_move_tensor(self.cache_seqlens_after, device)
|
||||
def get_cache_seqlens_after(self, device_idx: int) -> torch.Tensor:
|
||||
if self.cache_seqlens_after.device.index != device_idx:
|
||||
self.cache_seqlens_after = safe_move_tensor(self.cache_seqlens_after, device_idx, non_blocking = True)
|
||||
return self.cache_seqlens_after
|
||||
|
||||
|
||||
@@ -546,6 +547,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# @profile
|
||||
def forward_paged(self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache: ExLlamaV2CacheBase | None = None,
|
||||
@@ -558,8 +560,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
constants = self.model.get_device_tensors(self.device_idx, scratch = is_q)
|
||||
page_size = attn_params.page_size
|
||||
batch_size, q_len, _ = hidden_states.shape
|
||||
cache_seqlens = attn_params.get_cache_seqlens(self.device())
|
||||
block_table = attn_params.get_block_index(self.device())
|
||||
cache_seqlens = attn_params.get_cache_seqlens(self.device_idx)
|
||||
block_table = attn_params.get_block_index(self.device_idx)
|
||||
|
||||
# TODO: We only need keys/values when preprocess_only == True, so we could skip q projection and attention
|
||||
# on the last layer. Would need custom kernel to update paged cache if not calling flash_attn_with_kvcache
|
||||
@@ -599,7 +601,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
batch_size,
|
||||
q_len,
|
||||
0,
|
||||
attn_params.get_cache_seqlens(self.device()),
|
||||
cache_seqlens,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
@@ -629,7 +631,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
0,
|
||||
heads,
|
||||
cfg.head_dim,
|
||||
attn_params.get_cache_seqlens(self.device()),
|
||||
cache_seqlens,
|
||||
cfg.arch.rope_style == RopeStyle.NEOX
|
||||
)
|
||||
if attn_params.is_sequential:
|
||||
@@ -641,7 +643,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
if attn_params.is_sequential:
|
||||
k = None
|
||||
v = None
|
||||
cache_seqlens_a = attn_params.get_cache_seqlens_after(self.device())
|
||||
cache_seqlens_a = attn_params.get_cache_seqlens_after(self.device_idx)
|
||||
else:
|
||||
cache_seqlens_a = cache_seqlens
|
||||
|
||||
@@ -759,6 +761,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
|
||||
return attn_output
|
||||
|
||||
|
||||
# @profile
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache: ExLlamaV2CacheBase | None = None,
|
||||
|
||||
@@ -37,14 +37,15 @@ def test_gpu_peer_copy(device_a: torch.Device,
|
||||
|
||||
|
||||
def safe_move_tensor(tensor: torch.Tensor | tuple[torch.Tensor],
|
||||
device: torch.Device | str):
|
||||
device: torch.Device | str | int,
|
||||
non_blocking = False):
|
||||
|
||||
# Accept tensor or tuple of tensors
|
||||
|
||||
if isinstance(tensor, tuple):
|
||||
return tuple(safe_move_tensor(x, device) for x in tensor)
|
||||
|
||||
# Accept torch.device or string
|
||||
# Accept torch.device, string or int
|
||||
|
||||
device = torch.device(device)
|
||||
|
||||
@@ -56,13 +57,13 @@ def safe_move_tensor(tensor: torch.Tensor | tuple[torch.Tensor],
|
||||
# Copies to/from system RAM are always fine
|
||||
|
||||
if tensor.device.type == "cpu" or device.type == "cpu":
|
||||
return tensor.to(device)
|
||||
return tensor.to(device, non_blocking = non_blocking)
|
||||
|
||||
# Source and dest are distinct CUDA devices
|
||||
# Test tensor.to (once) and if it seems to be working, let Torch decide
|
||||
|
||||
if test_gpu_peer_copy(tensor.device, device):
|
||||
return tensor.to(device)
|
||||
return tensor.to(device, non_blocking = non_blocking)
|
||||
|
||||
# Force move tensor via CPU
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ import threading
|
||||
from typing import Callable
|
||||
# from exllamav2.util import list_live_tensors, print_vram_usage, set_snapshot, diff_snapshot, print_vram_usage_peak
|
||||
from exllamav2.util import get_basic_progress
|
||||
# from line_profiler import profile
|
||||
|
||||
|
||||
def _torch_device(idx):
|
||||
@@ -818,6 +819,7 @@ class ExLlamaV2:
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
# @profile
|
||||
def forward_chunk(self,
|
||||
input_ids: torch.Tensor,
|
||||
cache: ExLlamaV2CacheBase | list[ExLlamaV2CacheBase] | None = None,
|
||||
@@ -867,6 +869,7 @@ class ExLlamaV2:
|
||||
past_len = attn_params.past_len
|
||||
cache.current_seq_len = past_len
|
||||
|
||||
device = self.modules[0].device_idx
|
||||
for idx, module in enumerate(self.modules):
|
||||
|
||||
if idx == self.head_layer_idx and last_id_only:
|
||||
@@ -884,9 +887,11 @@ class ExLlamaV2:
|
||||
|
||||
# Onward
|
||||
|
||||
device = _torch_device(module.device_idx)
|
||||
n_device = _torch_device(module.device_idx)
|
||||
if n_device != device:
|
||||
x = safe_move_tensor(x, n_device, non_blocking = True)
|
||||
device = n_device
|
||||
|
||||
x = safe_move_tensor(x, device)
|
||||
x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras, **kwargs)
|
||||
|
||||
if preprocess_only and idx == self.last_kv_layer_idx:
|
||||
|
||||
Reference in New Issue
Block a user