diff --git a/extensions_built_in/diffusion_models/hidream/hidream_model.py b/extensions_built_in/diffusion_models/hidream/hidream_model.py
index 3322a111..eb02c644 100644
--- a/extensions_built_in/diffusion_models/hidream/hidream_model.py
+++ b/extensions_built_in/diffusion_models/hidream/hidream_model.py
@@ -129,6 +129,14 @@ class HidreamModel(BaseModel):
             torch_dtype=torch.bfloat16
         )
 
+        # count the params
+        sd = transformer.state_dict()
+        num_params = sum(p.numel() for p in sd.values())
+        print(f"Number of params in transformer: {num_params}")
+        # count params with 'expert' in their name
+        num_expert_params = sum(p.numel() for k, p in sd.items() if 'expert' in k)
+        print(f"Number of params in transformer with expert in them: {num_expert_params}")
+
         if not self.low_vram:
             transformer.to(self.device_torch, dtype=dtype)
 
diff --git a/extensions_built_in/diffusion_models/hidream/src/models/moe.py b/extensions_built_in/diffusion_models/hidream/src/models/moe.py
index 1745dcbe..3b3b6ce2 100644
--- a/extensions_built_in/diffusion_models/hidream/src/models/moe.py
+++ b/extensions_built_in/diffusion_models/hidream/src/models/moe.py
@@ -69,28 +69,29 @@ class MoEGate(nn.Module):
         if self.top_k > 1 and self.norm_topk_prob:
             denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
             topk_weight = topk_weight / denominator
+        # this was in the original and leaks memory; not needed
+
+        # ### expert-level computation auxiliary loss
+        # if self.training and self.alpha > 0.0:
+        #     scores_for_aux = scores
+        #     aux_topk = self.top_k
+        #     # always compute aux loss based on the naive greedy topk method
+        #     topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+        #     if self.seq_aux:
+        #         scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+        #         ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
+        #         ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
+        #         aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
+        #     else:
+        #         mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+        #         ce = mask_ce.float().mean(0)
-        ### expert-level computation auxiliary loss
-        if self.training and self.alpha > 0.0:
-            scores_for_aux = scores
-            aux_topk = self.top_k
-            # always compute aux loss based on the naive greedy topk method
-            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
-            if self.seq_aux:
-                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
-                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device)
-                ce.scatter_add_(1, topk_idx_for_aux_loss, torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device)).div_(seq_len * aux_topk / self.n_routed_experts)
-                aux_loss = (ce * scores_for_seq_aux.mean(dim = 1)).sum(dim = 1).mean() * self.alpha
-            else:
-                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
-                ce = mask_ce.float().mean(0)
-
-                Pi = scores_for_aux.mean(0)
-                fi = ce * self.n_routed_experts
-                aux_loss = (Pi * fi).sum() * self.alpha
-                save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
-        else:
-            aux_loss = None
+        #     Pi = scores_for_aux.mean(0)
+        #     fi = ce * self.n_routed_experts
+        #     aux_loss = (Pi * fi).sum() * self.alpha
+        #     save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
+        # else:
+        aux_loss = None
         return topk_idx, topk_weight, aux_loss
 
 
 # Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
@@ -119,20 +119,22 @@ class MOEFeedForwardSwiGLU(nn.Module):
         topk_idx, topk_weight, aux_loss = self.gate(x)
         x = x.view(-1, x.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
-        if self.training:
-            x = x.repeat_interleave(self.num_activated_experts, dim=0)
-            y = torch.empty_like(x, dtype=wtype)
-            for i, expert in enumerate(self.experts):
-                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
-            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-            y = y.view(*orig_shape).to(dtype=wtype)
-            #y = AddAuxiliaryLoss.apply(y, aux_loss)
-        else:
-            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+        y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+        # this was in the original and leaks memory; not needed
+        # if self.training:
+        #     x = x.repeat_interleave(self.num_activated_experts, dim=0)
+        #     y = torch.empty_like(x, dtype=wtype)
+        #     for i, expert in enumerate(self.experts):
+        #         y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(dtype=wtype)
+        #     y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
+        #     y = y.view(*orig_shape).to(dtype=wtype)
+        #     #y = AddAuxiliaryLoss.apply(y, aux_loss)
+        # else:
+        #     y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
         y = y + self.shared_experts(identity)
         return y
 
-    @torch.no_grad()
+    # @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
         expert_cache = torch.zeros_like(x)
         idxs = flat_expert_indices.argsort()
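Reviewer note: with the training branch removed, `moe_infer` now runs during training as well, which is presumably why its `@torch.no_grad()` decorator is commented out above: anything computed under `no_grad` is detached from the autograd graph, so the experts would never receive gradients. A minimal standalone sketch of that failure mode (names are hypothetical, not from this repo):

    import torch

    w = torch.randn(4, 4, requires_grad=True)

    @torch.no_grad()
    def expert_forward(x):
        # no autograd graph is recorded inside here
        return x @ w

    y = expert_forward(torch.randn(2, 4))
    print(y.requires_grad)  # False -> loss.backward() could never reach w

The leak the comments refer to plausibly has the opposite mechanics: `save_load_balancing_loss` stored `aux_loss` while it still referenced the autograd graph, so if the stored tuple was never consumed and released, each step kept its activations alive.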