# Mirror of https://github.com/ostris/ai-toolkit.git
# Synced 2026-02-19 11:59:58 +00:00 · 155 lines · 6.3 KiB · Python
import math
|
|
import torch
|
|
from torch import nn
|
|
import torch.nn.functional as F
|
|
from .attention import FeedForwardSwiGLU
|
|
from torch.distributed.nn.functional import all_gather
|
|
|
|
# Module-level registry of load-balancing statistics. MoEGate.forward appends
# one `(aux_loss, Pi, fi, alpha)` tuple per call while training; the list is
# read by batched_load_balancing_loss() and emptied by
# clear_load_balancing_loss().
_LOAD_BALANCING_LOSS = []
|
|
def save_load_balancing_loss(loss):
    """Append ``loss`` to the module-level load-balancing registry."""
    # The list is mutated in place (never rebound), so no `global` is needed.
    _LOAD_BALANCING_LOSS.append(loss)
|
|
|
|
def clear_load_balancing_loss():
    """Remove every entry from the module-level load-balancing registry.

    The same list object is kept, so references previously handed out by
    ``get_load_balancing_loss`` observe the emptied list.
    """
    del _LOAD_BALANCING_LOSS[:]
|
|
|
|
def get_load_balancing_loss():
    """Return the module-level load-balancing registry.

    The live list is returned (not a copy), so later appends and clears are
    visible through the returned reference.
    """
    return _LOAD_BALANCING_LOSS
|
|
|
|
def batched_load_balancing_loss():
    """Combine all recorded gate statistics into one auxiliary loss.

    Every registry entry is indexed as ``(aux_loss, Pi, fi, alpha)``. ``Pi``
    and ``fi`` are stacked over the entries; when a ``torch.distributed``
    process group is initialized, ``fi`` is additionally averaged over all
    ranks. The result is ``(Pi * fi).sum(-1).mean() * alpha`` where ``alpha``
    is taken from the first entry.

    Returns:
        A scalar tensor holding the batched load-balancing loss.

    Raises:
        IndexError: if no statistics have been recorded since the last clear.
    """
    aux_losses_arr = get_load_balancing_loss()
    if not aux_losses_arr:
        # Same exception type the bare `aux_losses_arr[0]` lookup would raise,
        # but with a message that points at the actual cause.
        raise IndexError(
            "no load-balancing losses recorded; run a training-mode forward "
            "pass through an MoE gate first"
        )

    alpha = aux_losses_arr[0][-1]
    Pi = torch.stack([ent[1] for ent in aux_losses_arr], dim=0)
    fi = torch.stack([ent[2] for ent in aux_losses_arr], dim=0)

    # all_gather needs an initialized process group. In a single-process run
    # the cross-rank mean of `fi` is `fi` itself, so the collective is skipped.
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        fi_list = all_gather(fi)
        fi = torch.stack(fi_list, 0).mean(0)

    aux_loss = (Pi * fi).sum(-1).mean() * alpha
    return aux_loss
|
|
|
|
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
|
class MoEGate(nn.Module):
    """Softmax top-k router for a mixture-of-experts layer.

    Each token is scored against ``num_routed_experts`` learned gating vectors;
    the ``num_activated_experts`` highest-scoring experts are selected per
    token. In training mode with ``aux_loss_alpha > 0`` an auxiliary
    load-balancing loss is also computed.

    Args:
        embed_dim: size of the last dimension of the input hidden states.
        num_routed_experts: number of experts to route between.
        num_activated_experts: number of experts selected per token (top-k).
        aux_loss_alpha: scale of the auxiliary loss; ``0`` disables it.
    """

    def __init__(self, embed_dim, num_routed_experts=4, num_activated_experts=2, aux_loss_alpha=0.01):
        super().__init__()
        self.top_k = num_activated_experts
        self.n_routed_experts = num_routed_experts

        self.scoring_func = 'softmax'
        self.alpha = aux_loss_alpha
        self.seq_aux = False

        # topk selection algorithm
        self.norm_topk_prob = False
        self.gating_dim = embed_dim
        self.weight = nn.Parameter(torch.empty((self.n_routed_experts, self.gating_dim)))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Initialize the gating matrix with Kaiming-uniform (``a=sqrt(5)``)."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def _balance_loss(self, scores, topk_idx, bsz, seq_len):
        """Compute the auxiliary load-balancing loss for one forward pass.

        With ``seq_aux`` disabled (the value set in ``__init__``) the
        per-expert statistics are also pushed to the module-level registry via
        ``save_load_balancing_loss``.
        """
        # always compute aux loss based on the naive greedy topk method
        topk_idx_for_aux_loss = topk_idx.view(bsz, -1)

        if self.seq_aux:
            # Per-sequence variant: normalized expert selection counts per
            # batch row, weighted by that row's mean routing score.
            scores_for_seq_aux = scores.view(bsz, seq_len, -1)
            ce = torch.zeros(bsz, self.n_routed_experts, device=scores.device)
            ce.scatter_add_(
                1,
                topk_idx_for_aux_loss,
                torch.ones(bsz, seq_len * self.top_k, device=scores.device),
            ).div_(seq_len * self.top_k / self.n_routed_experts)
            return (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha

        # Batch-level variant: fraction of routing slots assigned to each
        # expert (fi, scaled by the expert count) times its mean score (Pi).
        mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
        ce = mask_ce.float().mean(0)

        Pi = scores.mean(0)
        fi = ce * self.n_routed_experts
        aux_loss = (Pi * fi).sum() * self.alpha
        save_load_balancing_loss((aux_loss, Pi, fi, self.alpha))
        return aux_loss

    def forward(self, hidden_states):
        """Route tokens to experts.

        Args:
            hidden_states: tensor of shape ``(bsz, seq_len, h)``.

        Returns:
            ``(topk_idx, topk_weight, aux_loss)``. The first two have shape
            ``(bsz * seq_len, top_k)``; ``aux_loss`` is ``None`` unless the
            module is in training mode and ``alpha > 0``.

        Raises:
            NotImplementedError: if ``scoring_func`` is not ``'softmax'``.
        """
        bsz, seq_len, h = hidden_states.shape

        ### compute gating score
        # reshape (rather than view) so non-contiguous inputs, e.g. the result
        # of a transpose, are accepted; for contiguous input this is a view.
        hidden_states = hidden_states.reshape(-1, h)
        logits = F.linear(hidden_states, self.weight, None)
        if self.scoring_func == 'softmax':
            scores = logits.softmax(dim=-1)
        else:
            raise NotImplementedError(f'insupportable scoring function for MoE gating: {self.scoring_func}')

        ### select top-k experts
        topk_weight, topk_idx = torch.topk(scores, k=self.top_k, dim=-1, sorted=False)

        ### norm gate to sum 1
        if self.top_k > 1 and self.norm_topk_prob:
            denominator = topk_weight.sum(dim=-1, keepdim=True) + 1e-20
            topk_weight = topk_weight / denominator

        ### expert-level computation auxiliary loss
        if self.training and self.alpha > 0.0:
            aux_loss = self._balance_loss(scores, topk_idx, bsz, seq_len)
        else:
            aux_loss = None
        return topk_idx, topk_weight, aux_loss
|
|
|
|
# Modified from https://github.com/deepseek-ai/DeepSeek-V3/blob/main/inference/model.py
|
|
class MOEFeedForwardSwiGLU(nn.Module):
    """Mixture-of-experts feed-forward layer with one always-on shared expert.

    A ``MoEGate`` picks ``num_activated_experts`` of the ``num_routed_experts``
    routed ``FeedForwardSwiGLU`` experts for every token. Their outputs are
    combined using the gate weights and added to the output of a shared
    ``FeedForwardSwiGLU`` built with ``hidden_dim // 2``.

    Args:
        dim: size of the last dimension of the input.
        hidden_dim: hidden size passed to each routed expert.
        num_routed_experts: number of routed experts.
        num_activated_experts: number of experts used per token.
    """

    def __init__(
        self,
        dim: int,
        hidden_dim: int,
        num_routed_experts: int,
        num_activated_experts: int,
    ):
        super().__init__()
        self.shared_experts = FeedForwardSwiGLU(dim, hidden_dim // 2)
        self.experts = nn.ModuleList([FeedForwardSwiGLU(dim, hidden_dim) for _ in range(num_routed_experts)])
        self.gate = MoEGate(
            embed_dim=dim,
            num_routed_experts=num_routed_experts,
            num_activated_experts=num_activated_experts,
        )
        self.num_activated_experts = num_activated_experts

    def forward(self, x):
        """Apply the routed experts plus the shared expert to ``x``.

        The gate unpacks ``x.shape`` into three dimensions, so ``x`` is
        expected to be 3-D. The returned tensor has the same shape as ``x``.

        NOTE: outside training mode the routed experts run through
        ``moe_infer``, which is decorated with ``torch.no_grad()``.
        """
        wtype = x.dtype
        # Make the input contiguous up front so the flattening views used
        # during routing and dispatch are valid for any input layout. This is
        # a no-op for input that is already contiguous.
        x = x.contiguous()
        identity = x
        orig_shape = x.shape
        # The gate's auxiliary loss is not consumed here.
        topk_idx, topk_weight, _ = self.gate(x)
        x = x.view(-1, x.shape[-1])
        flat_topk_idx = topk_idx.view(-1)
        if self.training:
            # One row per (token, selected expert) pair, laid out in the same
            # order as flat_topk_idx.
            x = x.repeat_interleave(self.num_activated_experts, dim=0)
            y = torch.empty_like(x, dtype=wtype)
            for i, expert in enumerate(self.experts):
                # Build the selection mask once and reuse it for both the
                # gather and the scatter.
                selected = flat_topk_idx == i
                y[selected] = expert(x[selected]).to(dtype=wtype)
            # Weighted sum over the selected experts of each token.
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
            y = y.view(*orig_shape).to(dtype=wtype)
        else:
            y = self.moe_infer(x, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
        y = y + self.shared_experts(identity)
        return y

    @torch.no_grad()
    def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
        """Gradient-free expert dispatch that groups tokens by expert.

        Args:
            x: flattened tokens, shape ``(num_tokens, dim)``.
            flat_expert_indices: 1-D expert id per routing slot; slot ``s``
                belongs to token ``s // num_activated_experts``.
            flat_expert_weights: gate weight per routing slot, indexed like
                ``flat_expert_indices`` and broadcast against expert outputs.

        Returns:
            Tensor shaped like ``x`` holding, for every token, the weighted
            sum of the outputs of its selected experts.
        """
        expert_cache = torch.zeros_like(x)
        # Sorting the slots by expert id makes each expert's slots contiguous.
        idxs = flat_expert_indices.argsort()
        # Cumulative slot counts give each expert's end offset into `idxs`.
        tokens_per_expert = flat_expert_indices.bincount().cpu().numpy().cumsum(0)
        token_idxs = idxs // self.num_activated_experts
        for i, end_idx in enumerate(tokens_per_expert):
            start_idx = 0 if i == 0 else tokens_per_expert[i - 1]
            if start_idx == end_idx:
                # No slot was routed to this expert.
                continue
            expert = self.experts[i]
            exp_token_idx = token_idxs[start_idx:end_idx]
            expert_tokens = x[exp_token_idx]
            expert_out = expert(expert_tokens)
            expert_out.mul_(flat_expert_weights[idxs[start_idx:end_idx]])

            # for fp16 and other dtype
            expert_cache = expert_cache.to(expert_out.dtype)
            # Accumulate each weighted output row into its source token's row.
            expert_cache.scatter_reduce_(0, exp_token_idx.view(-1, 1).repeat(1, x.shape[-1]), expert_out, reduce='sum')
        return expert_cache
|