mirror of
https://github.com/ostris/ai-toolkit.git
synced 2026-04-30 11:11:37 +00:00
Remove some moe stuff for finetuning. Drastically reduces vram usage
This commit is contained in:
@@ -129,6 +129,14 @@ class HidreamModel(BaseModel):
|
|||||||
torch_dtype=torch.bfloat16
|
torch_dtype=torch.bfloat16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# count the params
|
||||||
|
sd = transformer.state_dict()
|
||||||
|
num_params = sum(p.numel() for p in sd.values())
|
||||||
|
print(f"Number of params in transformer: {num_params}")
|
||||||
|
# count params with name expert in them
|
||||||
|
num_expert_params = sum(p.numel() for k, p in sd.items() if 'expert' in k)
|
||||||
|
print(f"Number of params in transformer with expert in them: {num_expert_params}")
|
||||||
|
|
||||||
if not self.low_vram:
|
if not self.low_vram:
|
||||||
transformer.to(self.device_torch, dtype=dtype)
|
transformer.to(self.device_torch, dtype=dtype)
|
||||||
|
|
||||||
|
|||||||
@@ -69,28 +69,29 @@ class MoEGate(nn.Module):
|
|||||||
if self.top_k > 1 and self.norm_topk_prob:
|
if self.top_k > 1 and self.norm_topk_prob:
|
||||||
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
|
||||||
topk_weight = topk_weight / denominator
|
topk_weight = topk_weight / denominator
|
||||||
|
# this was in original and memory leaks, not needed
|
||||||
|
|
||||||
|
# ### expert-level computation auxiliary loss
|
||||||
|
# if self.training and self.alpha > 0.0:
|
||||||
|
# scores_for_aux = scores
|
||||||
|
# aux_topk = self.top_k
|
||||||
|
# # always compute aux loss based on the naive greedy topk method
|
||||||
|
# topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
||||||
|
# if self.seq_aux:
|
||||||
|
# scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
||||||
|
# ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
||||||
|
# ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
|
||||||
|
# aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
|
||||||
|
# else:
|
||||||
|
# mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
|
||||||
|
# ce = mask_ce.float().mean(0)
|
||||||
|
|
||||||
### expert-level computation auxiliary loss
|
# Pi = scores_for_aux.mean(0)
|
||||||
if self.training and self.alpha > 0.0:
|
# fi = ce * self.n_routed_experts
|
||||||
scores_for_aux = scores
|
# aux_loss = (Pi * fi).sum() * self.alpha
|
||||||
aux_topk = self.top_k
|
# save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
|
||||||
# always compute aux loss based on the naive greedy topk method
|
# else:
|
||||||
topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
|
aux_loss = None
|
||||||
if self.seq_aux:
|
|
||||||
scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
|
|
||||||
ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
|
|
||||||
ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
|
|
||||||
aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
|
|
||||||
else:
|
|
||||||
mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
|
|
||||||
ce = mask_ce.float().mean(0)
|
|
||||||
|
|
||||||
Pi = scores_for_aux.mean(0)
|
|
||||||
fi = ce * self.n_routed_experts
|
|
||||||
aux_loss = (Pi * fi).sum() * self.alpha
|
|
||||||
save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
|
|
||||||
else:
|
|
||||||
aux_loss = None
|
|
||||||
return topk_idx, topk_weight, aux_loss
|
return topk_idx, topk_weight, aux_loss
|
||||||
|
|
||||||
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
||||||
@@ -119,20 +120,22 @@ class MOEFeedForwardSwiGLU(nn.Module):
|
|||||||
topk_idx, topk_weight, aux_loss = self.gate(x)
|
topk_idx, topk_weight, aux_loss = self.gate(x)
|
||||||
x = x.view(-1, x.shape[-1])
|
x = x.view(-1, x.shape[-1])
|
||||||
flat_topk_idx = topk_idx.view(-1)
|
flat_topk_idx = topk_idx.view(-1)
|
||||||
if self.training:
|
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||||
x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
# this was in original and memory leaks, not needed
|
||||||
y = torch.empty_like(x, dtype=wtype)
|
# if self.training:
|
||||||
for i, expert in enumerate(self.experts):
|
# x = x.repeat_interleave(self.num_activated_experts, dim=0)
|
||||||
y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
|
# y = torch.empty_like(x, dtype=wtype)
|
||||||
y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
# for i, expert in enumerate(self.experts):
|
||||||
y = y.view(*orig_shape).to(dtype=wtype)
|
# y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
|
||||||
#y = AddAuxiliaryLoss.apply(y, aux_loss)
|
# y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
|
||||||
else:
|
# y = y.view(*orig_shape).to(dtype=wtype)
|
||||||
y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
# #y = AddAuxiliaryLoss.apply(y, aux_loss)
|
||||||
|
# else:
|
||||||
|
# y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
|
||||||
y = y + self.shared_experts(identity)
|
y = y + self.shared_experts(identity)
|
||||||
return y
|
return y
|
||||||
|
|
||||||
@torch.no_grad()
|
# @torch.no_grad()
|
||||||
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
|
||||||
expert_cache = torch.zeros_like(x)
|
expert_cache = torch.zeros_like(x)
|
||||||
idxs = flat_expert_indices.argsort()
|
idxs = flat_expert_indices.argsort()
|
||||||
|
|||||||
Reference in New Issue
Block a user