QMLP: Skip .view

This commit is contained in:
turboderp
2024-06-16 19:14:47 +02:00
parent 22d6823f98
commit 522cab53fa
2 changed files with 10 additions and 5 deletions

View File

@@ -84,15 +84,18 @@ void q_mlp_forward_
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
-    TORCH_CHECK(x.size(1) == mlp->up->height, "x is wrong shape");
-    TORCH_CHECK(x.size(0) <= mlp->max_rows, "Too many rows in x");
+    int dim = x.size(-1);
+    int rows = x.numel() / dim;
+    TORCH_CHECK(dim == mlp->up->height, "x is wrong shape");
+    TORCH_CHECK(rows <= mlp->max_rows, "Too many rows in x");
mlp->forward_
(
at::cuda::getCurrentCUDABlasHandle(),
(half*) x.data_ptr(),
-        x.size(0), // rows
-        x.size(1), // columns == hidden_size
+        rows,
+        dim,
loras,
loras_temp.device().is_meta() ? NULL : (half*) loras_temp.data_ptr()
);

View File

@@ -8,6 +8,7 @@ from exllamav2.layernorm import ExLlamaV2LayerNorm
from exllamav2.linear import ExLlamaV2Linear
from exllamav2.ext import exllamav2_ext as ext_c, none_tensor
from exllamav2.lora import ExLlamaV2Lora
+# from line_profiler import profile
from typing import TYPE_CHECKING
if TYPE_CHECKING:
@@ -225,6 +226,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
self.down_proj.set_device_idx(idx)
+    # @profile
def forward(self,
hidden_states: torch.Tensor,
cache = None,
@@ -245,7 +247,7 @@ class ExLlamaV2MLP(ExLlamaV2Module):
pass_lora_temp = torch.empty((self.temp_lora_size,), dtype = torch.half, device = hidden_states.device)
ext_c.q_mlp_forward_(self.q_handle,
-            hidden_states.view(-1, hidden_states.shape[-1]),
+            hidden_states,
pass_loras,
pass_lora_temp)