Remove some MoE auxiliary-loss machinery for finetuning. Drastically reduces VRAM usage

Jaret Burkett
2025-04-14 00:57:34 +00:00
parent f80cf99f40
commit 3a5ea2c742
2 changed files with 43 additions and 32 deletions


@@ -129,6 +129,14 @@ class HidreamModel(BaseModel):
             torch_dtype=torch.bfloat16
         )
+        # count the params
+        sd = transformer.state_dict()
+        num_params = sum(p.numel() for p in sd.values())
+        print(f"Number of params in transformer: {num_params}")
+        # count the params with 'expert' in their name
+        num_expert_params = sum(p.numel() for k, p in sd.items() if 'expert' in k)
+        print(f"Number of params in transformer with 'expert' in their name: {num_expert_params}")
         if not self.low_vram:
             transformer.to(self.device_torch, dtype=dtype)
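Expert weights typically dominate an MoE transformer's parameter count, which is why these counts are worth printing before deciding what to train. As a rough illustration of the VRAM stakes (the helper below and its example numbers are assumptions for illustration, not part of this commit), bf16 weights cost 2 bytes per parameter:

    def estimate_weight_gb(num_params: float, bytes_per_param: int = 2) -> float:
        # bf16 weights take 2 bytes per parameter; fp32 Adam moment buffers
        # would add roughly 8 more bytes per *trainable* parameter on top.
        return num_params * bytes_per_param / 1024**3

    # hypothetical reading of the prints above: 17e9 total, 10e9 expert params
    print(estimate_weight_gb(17e9))  # ~31.7 GB of bf16 weights in total
    print(estimate_weight_gb(10e9))  # ~18.6 GB of that would be experts alone

If the expert share dominates, anything that avoids keeping per-expert training state alive pays off quickly.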


@@ -69,28 +69,29 @@ class MoEGate(nn.Module):
         if self.top_k > 1 and self.norm_topk_prob:
             denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
             topk_weight = topk_weight / denominator
+        # this was in the original and leaks memory; not needed here
+        # ### expert-level computation auxiliary loss
+        # if self.training and self.alpha > 0.0:
+        #     scores_for_aux = scores
+        #     aux_topk = self.top_k
+        #     # always compute aux loss based on the naive greedy topk method
+        #     topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+        #     if self.seq_aux:
+        #         scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+        #         ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+        #         ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
+        #         aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
+        #     else:
+        #         mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+        #         ce = mask_ce.float().mean(0)
-        ### expert-level computation auxiliary loss
-        if self.training and self.alpha > 0.0:
-            scores_for_aux = scores
-            aux_topk = self.top_k
-            # always compute aux loss based on the naive greedy topk method
-            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
-            if self.seq_aux:
-                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
-                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
-                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
-                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
-            else:
-                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
-                ce = mask_ce.float().mean(0)
-                Pi = scores_for_aux.mean(0)
-                fi = ce * self.n_routed_experts
-                aux_loss = (Pi * fi).sum() * self.alpha
-                save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
-        else:
-            aux_loss = None
+        #         Pi = scores_for_aux.mean(0)
+        #         fi = ce * self.n_routed_experts
+        #         aux_loss = (Pi * fi).sum() * self.alpha
+        #         save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
+        # else:
+        aux_loss = None
         return topk_idx, topk_weight, aux_loss
 # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
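For reference, the block being commented out is the DeepSeek-style load-balancing auxiliary loss. A minimal standalone sketch of the seq-aux branch (the function name and flat argument shapes are my assumptions), should anyone want to compute it outside the gate:

    import torch

    def seq_aux_loss(scores, topk_idx, bsz, seq_len, n_routed_experts, top_k, alpha):
        # scores: (bsz*seq_len, n_routed_experts) router probabilities
        # topk_idx: (bsz*seq_len, top_k) selected expert indices
        scores_seq = scores.view(bsz, seq_len, -1)
        idx = topk_idx.view(bsz, -1)
        ce = torch.zeros(bsz, n_routed_experts, device=scores.device)
        # count selections per expert, scaled so a uniform router yields 1.0
        ce.scatter_add_(1, idx, torch.ones_like(idx, dtype=ce.dtype))
        ce.div_(seq_len * top_k / n_routed_experts)
        return (ce * scores_seq.mean(dim=1)).sum(dim=1).mean() * alpha

A plausible reading of the "leaks memory" note: save_load_balancing_loss((aux_loss, Pi, fi, self.alpha)) stashes live tensors every step, and any held reference to aux_loss keeps its whole autograd graph reachable until it is consumed or detached, so VRAM grows even though finetuning never uses the loss.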
@@ -119,20 +119,22 @@ class MOEFeedForwardSwiGLU(nn.Module):
         topk_idx, topk_weight, aux_loss = self.gate(x)
         x = x.view(-1, x.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
-        if self.training:
-            x = x.repeat_interleave(self.num_activated_experts, dim=0)
-            y = torch.empty_like(x, dtype=wtype)
-            for i, expert in enumerate(self.experts):
-                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
-            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-            y = y.view(*orig_shape).to(dtype=wtype)
-            # y = AddAuxiliaryLoss.apply(y, aux_loss)
-        else:
-            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+        y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+        # this was in the original and leaks memory; not needed
+        # if self.training:
+        #     x = x.repeat_interleave(self.num_activated_experts, dim=0)
+        #     y = torch.empty_like(x, dtype=wtype)
+        #     for i, expert in enumerate(self.experts):
+        #         y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
+        #     y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+        #     y = y.view(*orig_shape).to(dtype=wtype)
+        #     # y = AddAuxiliaryLoss.apply(y, aux_loss)
+        # else:
+        #     y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         y = y + self.shared_experts(identity)
         return y

-    @torch.no_grad()
+    # @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
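Routing every forward pass through moe_infer sidesteps the old training branch, whose repeat_interleave materialized num_activated_experts copies of the token activations plus an equally sized empty_like output buffer. A back-of-envelope comparison (all shapes hypothetical, for illustration only):

    # 4096 tokens, hidden size 2560, top-2 routing, bf16 (2 bytes/element)
    n_tokens, hidden, top_k, bytes_el = 4096, 2560, 2, 2
    train_path = 2 * n_tokens * top_k * hidden * bytes_el  # repeated x + empty_like y
    infer_path = 2 * n_tokens * hidden * bytes_el          # x + expert_cache
    print(f"old training path ~{train_path / 2**20:.0f} MiB, "
          f"moe_infer path ~{infer_path / 2**20:.0f} MiB per MoE layer")

Dropping the @torch.no_grad() decorator is what makes this workable for finetuning: with it in place, nothing computed inside moe_infer would be recorded by autograd, so no gradient could ever reach the expert weights through this path.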