Mirror of https://github.com/turboderp-org/exllamav3.git
(synced 2026-03-15 00:07:24 +00:00)
BlockSparseMLP: Keep buffers between experts when possible
This commit is contained in:
@@ -626,6 +626,11 @@ class BlockSparseMLP(Module):
|
||||
expert_ptr[1:] = expert_count.cumsum(0)
|
||||
expert_ptr = expert_ptr.tolist()
|
||||
|
||||
out_state = None
|
||||
interm = None
|
||||
interm_a = None
|
||||
max_count = 0
|
||||
|
||||
for expert_idx in range(num_ex):
|
||||
start = expert_ptr[expert_idx]
|
||||
end = expert_ptr[expert_idx + 1]
|
||||
@@ -638,26 +643,40 @@ class BlockSparseMLP(Module):
|
||||
|
||||
current_state = y.index_select(0, top_x)
|
||||
|
||||
if self.bc is not None and False:
|
||||
if self.bc is not None:
|
||||
if count <= TEMP_ROWS:
|
||||
self.bc.run_single_expert(current_state, expert_idx)
|
||||
current_state = self.experts_cfg.out_d2[:count]
|
||||
else:
|
||||
out_state = torch.empty(
|
||||
(count, self.hidden_size),
|
||||
dtype = self.out_dtype or torch.half,
|
||||
device = self.device
|
||||
)
|
||||
interm = torch.empty(
|
||||
(count * 2, self.intermediate_size),
|
||||
dtype = self.interm_dtype,
|
||||
device = self.device
|
||||
)
|
||||
interm_a = interm[:count] if self.interm_dtype == torch.half else \
|
||||
torch.empty_like(interm[:count], dtype = torch.half)
|
||||
if count > max_count:
|
||||
out_state = torch.empty(
|
||||
(count, self.hidden_size),
|
||||
dtype = self.out_dtype or torch.half,
|
||||
device = self.device
|
||||
)
|
||||
interm = torch.empty(
|
||||
(count * 2, self.intermediate_size),
|
||||
dtype = self.interm_dtype,
|
||||
device = self.device
|
||||
)
|
||||
interm_a = interm[:count] if self.interm_dtype == torch.half else \
|
||||
torch.empty_like(interm[:count], dtype = torch.half)
|
||||
out_state_ = out_state
|
||||
interm_ = interm
|
||||
interm_a_ = interm_a
|
||||
max_count = count
|
||||
elif count == max_count:
|
||||
out_state_ = out_state
|
||||
interm_ = interm
|
||||
interm_a_ = interm_a
|
||||
else:
|
||||
out_state_ = out_state[:count]
|
||||
interm_ = interm[:count * 2]
|
||||
interm_a_ = interm_a[:count]
|
||||
|
||||
yh = torch.empty((count * 2, self.hidden_size), dtype = torch.half, device = self.device)
|
||||
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
|
||||
current_state = out_state
|
||||
self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm_, interm_a_, out_state)
|
||||
current_state = out_state_
|
||||
else:
|
||||
def mlp(exp_i, xc):
|
||||
g = self.gates[exp_i].forward(xc, params)
|
||||
|
||||
Reference in New Issue
Block a user