diff --git a/exllamav3/modules/block_sparse_mlp.py b/exllamav3/modules/block_sparse_mlp.py
index 1646c5f..115f34c 100644
--- a/exllamav3/modules/block_sparse_mlp.py
+++ b/exllamav3/modules/block_sparse_mlp.py
@@ -626,6 +626,11 @@ class BlockSparseMLP(Module):
         expert_ptr[1:] = expert_count.cumsum(0)
         expert_ptr = expert_ptr.tolist()
 
+        out_state = None
+        interm = None
+        interm_a = None
+        max_count = 0
+
         for expert_idx in range(num_ex):
             start = expert_ptr[expert_idx]
             end = expert_ptr[expert_idx + 1]
@@ -638,26 +643,40 @@ class BlockSparseMLP(Module):
 
             current_state = y.index_select(0, top_x)
 
-            if self.bc is not None and False:
+            if self.bc is not None:
                 if count <= TEMP_ROWS:
                     self.bc.run_single_expert(current_state, expert_idx)
                     current_state = self.experts_cfg.out_d2[:count]
                 else:
-                    out_state = torch.empty(
-                        (count, self.hidden_size),
-                        dtype = self.out_dtype or torch.half,
-                        device = self.device
-                    )
-                    interm = torch.empty(
-                        (count * 2, self.intermediate_size),
-                        dtype = self.interm_dtype,
-                        device = self.device
-                    )
-                    interm_a = interm[:count] if self.interm_dtype == torch.half else \
-                        torch.empty_like(interm[:count], dtype = torch.half)
+                    if count > max_count:
+                        out_state = torch.empty(
+                            (count, self.hidden_size),
+                            dtype = self.out_dtype or torch.half,
+                            device = self.device
+                        )
+                        interm = torch.empty(
+                            (count * 2, self.intermediate_size),
+                            dtype = self.interm_dtype,
+                            device = self.device
+                        )
+                        interm_a = interm[:count] if self.interm_dtype == torch.half else \
+                            torch.empty_like(interm[:count], dtype = torch.half)
+                        out_state_ = out_state
+                        interm_ = interm
+                        interm_a_ = interm_a
+                        max_count = count
+                    elif count == max_count:
+                        out_state_ = out_state
+                        interm_ = interm
+                        interm_a_ = interm_a
+                    else:
+                        out_state_ = out_state[:count]
+                        interm_ = interm[:count * 2]
+                        interm_a_ = interm_a[:count]
+
                     yh = torch.empty((count * 2, self.hidden_size), dtype = torch.half, device = self.device)
-                    self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm, interm_a, out_state)
-                    current_state = out_state
+                    self.bc.run_single_expert_dq(current_state, expert_idx, yh, interm_, interm_a_, out_state)
+                    current_state = out_state_
             else:
                 def mlp(exp_i, xc):
                     g = self.gates[exp_i].forward(xc, params)