Non-blocking host-device copies in forward pass

This commit is contained in:
turboderp
2024-06-16 19:18:01 +02:00
parent 522cab53fa
commit 843cec5206
3 changed files with 26 additions and 17 deletions

View File

@@ -15,6 +15,7 @@ from exllamav2.architecture import RopeStyle
import math
# from exllamav2.util import list_live_tensors, set_snapshot, diff_snapshot, print_vram_usage_peak
import torch.nn.functional as F
# from line_profiler import profile
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -255,14 +256,14 @@ class ExLlamaV2Attention(ExLlamaV2Module):
self.block_index = safe_move_tensor(self.block_index, device)
return self.block_index
def get_cache_seqlens(self, device) -> torch.Tensor:
if self.cache_seqlens.device != device:
self.cache_seqlens = safe_move_tensor(self.cache_seqlens, device)
def get_cache_seqlens(self, device_idx: int) -> torch.Tensor:
    """Return the cache_seqlens tensor, relocating it to GPU ``device_idx`` first when needed.

    The move is issued with non_blocking = True so the host-to-device copy can
    overlap with other work in the forward pass; the relocated tensor is cached
    on ``self`` so subsequent calls are free.
    """
    seqlens = self.cache_seqlens
    if seqlens.device.index != device_idx:
        seqlens = safe_move_tensor(seqlens, device_idx, non_blocking = True)
        self.cache_seqlens = seqlens
    return seqlens
def get_cache_seqlens_after(self, device) -> torch.Tensor:
if self.cache_seqlens_after.device != device:
self.cache_seqlens_after = safe_move_tensor(self.cache_seqlens_after, device)
def get_cache_seqlens_after(self, device_idx: int) -> torch.Tensor:
    """Return the cache_seqlens_after tensor, relocating it to GPU ``device_idx`` first when needed.

    Mirrors get_cache_seqlens: the copy is non-blocking so it can overlap with
    compute, and the moved tensor is stored back on ``self`` so the transfer
    happens at most once per device.
    """
    seqlens_after = self.cache_seqlens_after
    if seqlens_after.device.index != device_idx:
        seqlens_after = safe_move_tensor(seqlens_after, device_idx, non_blocking = True)
        self.cache_seqlens_after = seqlens_after
    return seqlens_after
@@ -546,6 +547,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
return hidden_states
# @profile
def forward_paged(self,
hidden_states: torch.Tensor,
cache: ExLlamaV2CacheBase | None = None,
@@ -558,8 +560,8 @@ class ExLlamaV2Attention(ExLlamaV2Module):
constants = self.model.get_device_tensors(self.device_idx, scratch = is_q)
page_size = attn_params.page_size
batch_size, q_len, _ = hidden_states.shape
cache_seqlens = attn_params.get_cache_seqlens(self.device())
block_table = attn_params.get_block_index(self.device())
cache_seqlens = attn_params.get_cache_seqlens(self.device_idx)
block_table = attn_params.get_block_index(self.device_idx)
# TODO: We only need keys/values when preprocess_only == True, so we could skip q projection and attention
# on the last layer. Would need custom kernel to update paged cache if not calling flash_attn_with_kvcache
@@ -599,7 +601,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
batch_size,
q_len,
0,
attn_params.get_cache_seqlens(self.device()),
cache_seqlens,
q,
k,
v,
@@ -629,7 +631,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
0,
heads,
cfg.head_dim,
attn_params.get_cache_seqlens(self.device()),
cache_seqlens,
cfg.arch.rope_style == RopeStyle.NEOX
)
if attn_params.is_sequential:
@@ -641,7 +643,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
if attn_params.is_sequential:
k = None
v = None
cache_seqlens_a = attn_params.get_cache_seqlens_after(self.device())
cache_seqlens_a = attn_params.get_cache_seqlens_after(self.device_idx)
else:
cache_seqlens_a = cache_seqlens
@@ -759,6 +761,7 @@ class ExLlamaV2Attention(ExLlamaV2Module):
return attn_output
# @profile
def forward(self,
hidden_states: torch.Tensor,
cache: ExLlamaV2CacheBase | None = None,

View File

@@ -37,14 +37,15 @@ def test_gpu_peer_copy(device_a: torch.Device,
def safe_move_tensor(tensor: torch.Tensor | tuple[torch.Tensor],
device: torch.Device | str):
device: torch.Device | str | int,
non_blocking = False):
# Accept tensor or tuple of tensors
if isinstance(tensor, tuple):
return tuple(safe_move_tensor(x, device) for x in tensor)
# Accept torch.device or string
# Accept torch.device, string or int
device = torch.device(device)
@@ -56,13 +57,13 @@ def safe_move_tensor(tensor: torch.Tensor | tuple[torch.Tensor],
# Copies to/from system RAM are always fine
if tensor.device.type == "cpu" or device.type == "cpu":
return tensor.to(device)
return tensor.to(device, non_blocking = non_blocking)
# Source and dest are distinct CUDA devices
# Test tensor.to (once) and if it seems to be working, let Torch decide
if test_gpu_peer_copy(tensor.device, device):
return tensor.to(device)
return tensor.to(device, non_blocking = non_blocking)
# Force move tensor via CPU

View File

@@ -48,6 +48,7 @@ import threading
from typing import Callable
# from exllamav2.util import list_live_tensors, print_vram_usage, set_snapshot, diff_snapshot, print_vram_usage_peak
from exllamav2.util import get_basic_progress
# from line_profiler import profile
def _torch_device(idx):
@@ -818,6 +819,7 @@ class ExLlamaV2:
@torch.inference_mode()
# @profile
def forward_chunk(self,
input_ids: torch.Tensor,
cache: ExLlamaV2CacheBase | list[ExLlamaV2CacheBase] | None = None,
@@ -867,6 +869,7 @@ class ExLlamaV2:
past_len = attn_params.past_len
cache.current_seq_len = past_len
device = self.modules[0].device_idx
for idx, module in enumerate(self.modules):
if idx == self.head_layer_idx and last_id_only:
@@ -884,9 +887,11 @@ class ExLlamaV2:
# Onward
device = _torch_device(module.device_idx)
n_device = _torch_device(module.device_idx)
if n_device != device:
x = safe_move_tensor(x, n_device, non_blocking = True)
device = n_device
x = safe_move_tensor(x, device)
x = module.forward(x, cache = cache, attn_params = attn_params, past_len = past_len, loras = loras, **kwargs)
if preprocess_only and idx == self.last_kv_layer_idx: