mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-04-20 14:29:28 +00:00
QMLP: Skip .view
This commit is contained in:
@@ -84,15 +84,18 @@ void q_mlp_forward_
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
|
||||
TORCH_CHECK(x.size(1) == mlp->up->height, "x is wrong shape");
|
||||
TORCH_CHECK(x.size(0) <= mlp->max_rows, "Too many rows in x");
|
||||
int dim = x.size(-1);
|
||||
int rows = x.numel() / dim;
|
||||
|
||||
TORCH_CHECK(dim == mlp->up->height, "x is wrong shape");
|
||||
TORCH_CHECK(rows <= mlp->max_rows, "Too many rows in x");
|
||||
|
||||
mlp->forward_
|
||||
(
|
||||
at::cuda::getCurrentCUDABlasHandle(),
|
||||
(half*) x.data_ptr(),
|
||||
x.size(0), // rows
|
||||
x.size(1), // columns == hidden_size
|
||||
rows,
|
||||
dim,
|
||||
loras,
|
||||
loras_temp.device().is_meta() ? NULL : (half*) loras_temp.data_ptr()
|
||||
);
|
||||
|
||||
@@ -8,6 +8,7 @@ from exllamav2.layernorm import ExLlamaV2LayerNorm
|
||||
from exllamav2.linear import ExLlamaV2Linear
|
||||
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
|
||||
from exllamav2.lora import ExLlamaV2Lora
|
||||
# from line_profiler import profile
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
if TYPE_CHECKING:
|
||||
@@ -225,6 +226,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
self.down_proj.set_device_idx(idx)
|
||||
|
||||
|
||||
# @profile
|
||||
def forward(self,
|
||||
hidden_states: torch.Tensor,
|
||||
cache = None,
|
||||
@@ -245,7 +247,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
|
||||
pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
|
||||
|
||||
ext_c.q_mlp_forward_(self.q_handle,
|
||||
hidden_states.view(-1, hidden_states.shape[-1]),
|
||||
hidden_states,
|
||||
pass_loras,
|
||||
pass_lora_temp)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user