Mirror of https://github.com/ostris/ai-toolkit.git (synced 2026-01-26 16:39:47 +00:00)
Remove some MoE stuff for finetuning; drastically reduces VRAM usage.
@@ -129,6 +129,14 @@ class HidreamModel(BaseModel):
             torch_dtype=torch.bfloat16
         )
 
+        # count the params
+        sd = transformer.state_dict()
+        num_params = sum(p.numel() for p in sd.values())
+        print(f"Number of params in transformer: {num_params}")
+        # count params with name expert in them
+        num_expert_params = sum(p.numel() for k, p in sd.items() if 'expert' in k)
+        print(f"Number of params in transformer with expert in them: {num_expert_params}")
+
         if not self.low_vram:
             transformer.to(self.device_torch, dtype=dtype)
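For reference, the counting added above is just a substring match over state_dict keys. A minimal sketch on a hypothetical toy module (the module and names below are illustrative, not from the repo) shows the same idiom used to measure what share of the parameters sits in the expert layers:

import torch.nn as nn

# Hypothetical toy block: one shared projection plus a few "expert" layers,
# mirroring the naming convention the counting code above keys on.
class ToyMoEBlock(nn.Module):
    def __init__(self, dim=64, n_experts=4):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))

sd = ToyMoEBlock().state_dict()
total = sum(p.numel() for p in sd.values())
# Any key containing the substring 'expert' counts toward the expert share,
# exactly as in the transformer counting above.
expert = sum(p.numel() for k, p in sd.items() if 'expert' in k)
print(f"total: {total}, expert: {expert} ({expert / total:.0%})")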
@@ -69,28 +69,29 @@ class MoEGate(nn.Module):
         if self.top_k > 1 and self.norm_topk_prob:
             denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
             topk_weight = topk_weight / denominator
-        ### expert-level computation auxiliary loss
-        if self.training and self.alpha > 0.0:
-            scores_for_aux = scores
-            aux_topk = self.top_k
-            # always compute aux loss based on the naive greedy topk method
-            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
-            if self.seq_aux:
-                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
-                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
-                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
-                aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
-            else:
-                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
-                ce = mask_ce.float().mean(0)
-                Pi = scores_for_aux.mean(0)
-                fi = ce * self.n_routed_experts
-                aux_loss = (Pi * fi).sum() * self.alpha
-                save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
-        else:
-            aux_loss = None
+        # this was in original and memory leaks, not needed
+        # ### expert-level computation auxiliary loss
+        # if self.training and self.alpha > 0.0:
+        #     scores_for_aux = scores
+        #     aux_topk = self.top_k
+        #     # always compute aux loss based on the naive greedy topk method
+        #     topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+        #     if self.seq_aux:
+        #         scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+        #         ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+        #         ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
+        #         aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
+        #     else:
+        #         mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+        #         ce = mask_ce.float().mean(0)
+        #         Pi = scores_for_aux.mean(0)
+        #         fi = ce * self.n_routed_experts
+        #         aux_loss = (Pi * fi).sum() * self.alpha
+        #         save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
+        # else:
+        aux_loss = None
         return topk_idx, topk_weight, aux_loss
 
 
 # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
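For context on what the commented-out branch computed: in the non-seq_aux path, the gate's auxiliary load-balancing loss is the mean router probability per expert multiplied by the rescaled fraction of top-k assignments that expert received, summed over experts and scaled by alpha. Restating the removed code (T routed tokens, k = top_k, N = n_routed_experts):

\[
P_i = \frac{1}{T}\sum_{t=1}^{T} \mathrm{scores}_{t,i},
\qquad
f_i = \frac{N}{T\,k}\sum_{t=1}^{T}\sum_{j=1}^{k} \mathbb{1}\!\left[\mathrm{topk\_idx}_{t,j}=i\right],
\qquad
\mathcal{L}_{\mathrm{aux}} = \alpha \sum_{i=1}^{N} P_i\, f_i .
\]

Dropping this for finetuning removes both the extra routing tensors and the save_load_balancing_loss bookkeeping, which held references to them; that bookkeeping is the likely source of the memory growth the new comment mentions.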
@@ -119,20 +120,22 @@ class MOEFeedForwardSwiGLU(nn.Module):
         topk_idx, topk_weight, aux_loss = self.gate(x)
         x = x.view(-1, x.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
-        if self.training:
-            x = x.repeat_interleave(self.num_activated_experts, dim=0)
-            y = torch.empty_like(x, dtype=wtype)
-            for i, expert in enumerate(self.experts):
-                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
-            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-            y = y.view(*orig_shape).to(dtype=wtype)
-            #y = AddAuxiliaryLoss.apply(y, aux_loss)
-        else:
-            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+        y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+        # this was in original and memory leaks, not needed
+        # if self.training:
+        #     x = x.repeat_interleave(self.num_activated_experts, dim=0)
+        #     y = torch.empty_like(x, dtype=wtype)
+        #     for i, expert in enumerate(self.experts):
+        #         y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
+        #     y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+        #     y = y.view(*orig_shape).to(dtype=wtype)
+        #     #y = AddAuxiliaryLoss.apply(y, aux_loss)
+        # else:
+        #     y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         y = y + self.shared_experts(identity)
         return y
 
-    @torch.no_grad()
+    # @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
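To see why routing training through moe_infer cuts activation memory: the removed training branch duplicated every token top_k times with repeat_interleave and ran each expert over a boolean-masked copy, so a top_k-times-larger activation (plus per-expert masks) stayed alive for backward. moe_infer instead sorts the token/slot assignments by expert and runs each expert once over its contiguous group. Commenting out the @torch.no_grad() decorator is what lets gradients flow through that path during training. Below is a self-contained sketch of the two dispatch styles; the toy experts and helper names are hypothetical, and the grouped version only follows the spirit of moe_infer, not its exact implementation:

import torch
import torch.nn as nn

torch.manual_seed(0)
dim, n_tokens, n_experts, top_k = 8, 6, 4, 2
experts = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_experts))

x = torch.randn(n_tokens, dim)
# Fake router output: top-k expert ids and weights per token.
topk_idx = torch.randint(0, n_experts, (n_tokens, top_k))
topk_weight = torch.softmax(torch.randn(n_tokens, top_k), dim=-1)
flat_idx = topk_idx.view(-1)

# (a) Dense-style dispatch, like the removed training branch: every token is
# duplicated top_k times before the expert loop, so activations grow by top_k.
x_rep = x.repeat_interleave(top_k, dim=0)
y_dense = torch.empty_like(x_rep)
for i, expert in enumerate(experts):
    mask = flat_idx == i
    y_dense[mask] = expert(x_rep[mask])
y_dense = (y_dense.view(n_tokens, top_k, dim) * topk_weight.unsqueeze(-1)).sum(dim=1)

# (b) Grouped dispatch in the spirit of moe_infer: sort token slots by expert
# id, run each expert once over its contiguous group, and scatter the weighted
# outputs back; no duplicated token activations are kept for the whole layer.
order = flat_idx.argsort()
counts = torch.bincount(flat_idx, minlength=n_experts)
y_grouped = torch.zeros(n_tokens, dim)
start = 0
for i, expert in enumerate(experts):
    end = start + counts[i].item()
    slots = order[start:end]        # positions into the flattened (token, k) grid
    tokens = slots // top_k         # original token each slot belongs to
    out = expert(x[tokens]) * topk_weight.view(-1, 1)[slots]
    y_grouped.index_add_(0, tokens, out)
    start = end

# Same result, very different peak memory profile.
print(torch.allclose(y_dense, y_grouped, atol=1e-5))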