def detect_fused_experts(experts: nn.Module) -> bool:
    """Report whether ``experts`` uses the fused (packed) expert layout.

    In the fused layout (transformers v5 style) the experts live in a single
    module that carries two 3-D tensors, ``gate_up_proj`` and ``down_proj``,
    instead of a ModuleList of per-expert Linear layers. Only the attribute
    type and tensor rank are inspected here; the individual dimension sizes
    are not validated.

    Args:
        experts: Experts container to inspect. ``None`` is tolerated.

    Returns:
        ``True`` when both attributes exist, are tensors and are 3-D;
        ``False`` in every other case.
    """
    if experts is None:
        return False

    # Missing attributes resolve to None and therefore fail the tensor check.
    packed = [getattr(experts, attr, None) for attr in ("gate_up_proj", "down_proj")]
    if not all(isinstance(candidate, torch.Tensor) for candidate in packed):
        return False
    return all(candidate.dim() == 3 for candidate in packed)
Single-GPU: sync directly - cpu_output = wrapper.sync_forward(output_device=original_device) + cpu_output = wrapper.sync_forward_sft(output_device=original_device) output = cpu_output.view(batch_size, seq_len, hidden_size).to(dtype=original_dtype) else: # Broadcast-only rank (no wrapper) diff --git a/kt-kernel/python/sft/layer.py b/kt-kernel/python/sft/layer.py index c20d5d93..fa889721 100644 --- a/kt-kernel/python/sft/layer.py +++ b/kt-kernel/python/sft/layer.py @@ -82,10 +82,6 @@ class KTMoELayerWrapper(nn.Module): # PEFT LoRA tracking (set by kt_adapt_peft_lora) # _peft_lora_modules: {expert_idx: {proj_name: (lora_A, lora_B)}} self._peft_lora_modules: dict[int, dict[str, tuple[nn.Module, nn.Module]]] | None = None - self._peft_lora_rank: int = 0 - self._peft_lora_alpha: float = 0.0 - self._skip_lora: bool = False # True when using SkipLoRA backend (no LoRA on experts) - self._lora_pointers_dirty = False def _apply(self, fn, recurse=True): @@ -210,7 +206,7 @@ class KTMoELayerWrapper(nn.Module): if rank == 0: if self.wrapper is None: raise RuntimeError("Rank0 wrapper is required in distributed KT overlap path.") - cpu_output = self.wrapper.sync_forward(output_device=original_device) + cpu_output = self.wrapper.sync_forward_sft(output_device=original_device) cpu_output = cpu_output.to(dtype=original_dtype).view(total_qlen, self.hidden_size) offsets = _qlen_offsets(all_qlens_list) scatter_list = [cpu_output[offsets[i] : offsets[i + 1]].contiguous() for i in range(world_size)] @@ -231,7 +227,7 @@ class KTMoELayerWrapper(nn.Module): return output if self.wrapper is not None: - cpu_output = self.wrapper.sync_forward(output_device=original_device) + cpu_output = self.wrapper.sync_forward_sft(output_device=original_device) output = cpu_output.view(batch_size, seq_len, self.hidden_size).to(dtype=original_dtype) return output @@ -263,7 +259,18 @@ class KTMoELayerWrapper(nn.Module): topk_weights = topk_weights.to(torch.bfloat16) return topk_ids, topk_weights - router_logits 
= router(hidden_states.view(-1, self.hidden_size)) + router_output = router(hidden_states.view(-1, self.hidden_size)) + # transformers v5 TopKRouter returns (router_logits, router_scores, router_indices) + # directly — scores/indices are already topk-normalized. + if isinstance(router_output, tuple): + if len(router_output) >= 3: + _logits, topk_weights, topk_ids = router_output[0], router_output[1], router_output[2] + if topk_weights.is_floating_point(): + topk_weights = topk_weights.to(torch.bfloat16) + return topk_ids, topk_weights + router_output = router_output[0] + + router_logits = router_output routing_weights = F.softmax(router_logits, dim=-1, dtype=torch.float32) topk_weights, topk_ids = torch.topk(routing_weights, self.moe_config.num_experts_per_tok, dim=-1) topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) @@ -328,7 +335,7 @@ class KTMoELayerWrapper(nn.Module): all_hs = torch.cat(gathered_hs, dim=0) all_ids = torch.cat(gathered_ids, dim=0) all_wts = torch.cat(gathered_wts, dim=0) - self.wrapper.submit_forward( + self.wrapper.submit_forward_sft( all_hs, all_ids, all_wts, @@ -357,7 +364,7 @@ class KTMoELayerWrapper(nn.Module): submit_hs = input_flat.detach() submit_ids = expert_ids.detach() submit_wts = weights.detach() - self.wrapper.submit_forward( + self.wrapper.submit_forward_sft( submit_hs, submit_ids, submit_wts, diff --git a/kt-kernel/python/sft/lora.py b/kt-kernel/python/sft/lora.py index d949edf9..5a594ec8 100644 --- a/kt-kernel/python/sft/lora.py +++ b/kt-kernel/python/sft/lora.py @@ -118,6 +118,10 @@ def get_kt_lora_params(model: nn.Module) -> list[nn.Parameter]: params.append(lora_A.weight) if hasattr(lora_B, 'weight') and lora_B.weight.requires_grad: params.append(lora_B.weight) + # Fused expert LoRA parameters (KT-managed, not PEFT) + fused_params = getattr(wrapper, "_fused_expert_lora_params", None) + if fused_params is not None: + params.extend(fused_params) # lora_experts parameters (separate feature) if 
getattr(wrapper, "lora_experts", None) is not None: params.extend(wrapper.lora_experts.parameters()) @@ -163,7 +167,34 @@ def kt_adapt_peft_lora(model: nn.Module) -> None: experts_attr = getattr(wrapper, "_experts_attr", "experts") experts = getattr(wrapper, experts_attr, None) - if experts is None or len(experts) == 0: + if experts is None: + continue + + # Fused experts (transformers v5): PEFT cannot auto-attach LoRA to packed + # nn.Parameter tensors. Create KT-managed LoRA buffers with proper init, + # wrap as nn.Parameter for optimizer, and pre-assign .grad for C++ backward. + if getattr(wrapper, "_fused_experts", False): + lora_rank = getattr(wrapper, "_lora_rank", 1) + lora_buffers, lora_grad_buffers, lora_params = _create_fused_expert_lora_buffers( + wrapper, moe_config, lora_rank, torch.bfloat16, + ) + + if is_rank_0 and wrapper.wrapper is not None: + all_buffers = {} + all_buffers.update(lora_buffers) + all_buffers.update(lora_grad_buffers) + wrapper.wrapper.init_lora_weights(**all_buffers) + logger.info( + f"[kt_adapt_peft_lora] Layer {layer_idx}: fused expert LoRA " + f"(r={lora_rank}, E={moe_config.expert_num})" + ) + + wrapper._fused_expert_lora_params = lora_params + wrapper._peft_lora_modules = None + adapted_count += 1 + continue + + if len(experts) == 0: continue # Collect references to PEFT LoRA modules for each expert @@ -197,21 +228,11 @@ def kt_adapt_peft_lora(model: nn.Module) -> None: # Store PEFT LoRA references on wrapper wrapper._peft_lora_modules = peft_lora_modules - # SkipLoRA mode: if no LoRA found on experts, skip buffer creation if not peft_lora_modules: - if getattr(wrapper, '_skip_lora', False): - logger.info( - f"[kt_adapt_peft_lora] Layer {layer_idx}: SkipLoRA mode, " - f"no PEFT LoRA on experts — skipping LoRA buffer creation" - ) - adapted_count += 1 - continue - else: - raise RuntimeError( - f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. 
" - f"If you intend to train without expert LoRA, use a SkipLoRA backend " - f"(e.g., kt_backend: AMXINT8_SkipLoRA)." - ) + raise RuntimeError( + f"[kt_adapt_peft_lora] Layer {layer_idx}: No PEFT LoRA found on any expert. " + f"Check that PEFT lora_target includes expert modules." + ) # Allocate contiguous bf16 buffers and populate with initial PEFT values (all ranks) lora_buffers = _create_lora_view_buffers(peft_lora_modules, moe_config, torch.bfloat16) @@ -243,6 +264,8 @@ def kt_adapt_peft_lora(model: nn.Module) -> None: experts = getattr(wrapper, experts_attr, None) if experts is None: continue + if getattr(wrapper, "_fused_experts", False): + continue for expert in experts: for param_name, param in list(expert.named_parameters()): if param.requires_grad: @@ -372,6 +395,62 @@ def _create_lora_grad_buffers( return buffers +def _create_fused_expert_lora_buffers( + wrapper, + moe_config: MOEArchConfig, + lora_rank: int, + dtype: torch.dtype = torch.bfloat16, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor], list[nn.Parameter]]: + """ + Create KT-managed LoRA buffers for fused expert modules. + + Fused experts store weights as 3D parameters (gate_up_proj [E, 2I, H], down_proj [E, H, I]) + rather than per-expert nn.Linear modules. PEFT can't attach per-expert LoRA to these, + so we create our own LoRA buffers that the C++ kernel reads/writes directly. 
+ + Returns: + (lora_buffers, lora_grad_buffers, lora_params): + - lora_buffers: dict of weight buffers for C++ init_lora_weights() + - lora_grad_buffers: dict of grad buffers for C++ backward + - lora_params: list of nn.Parameter wrappers for the optimizer + """ + E = moe_config.expert_num + I = moe_config.intermediate_size + H = wrapper.hidden_size + r = lora_rank + + logger.info(f"[_create_fused_expert_lora_buffers] E={E}, I={I}, H={H}, r={r}") + + lora_buffers = { + "gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"), + "gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"), + "up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"), + "up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"), + "down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"), + "down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"), + } + + for key in ("gate_lora_a", "up_lora_a", "down_lora_a"): + nn.init.kaiming_uniform_(lora_buffers[key].view(E * r, -1), a=math.sqrt(5)) + + lora_grad_buffers = { + "grad_gate_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"), + "grad_gate_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"), + "grad_up_lora_a": torch.zeros(E, r, H, dtype=dtype, device="cpu"), + "grad_up_lora_b": torch.zeros(E, I, r, dtype=dtype, device="cpu"), + "grad_down_lora_a": torch.zeros(E, r, I, dtype=dtype, device="cpu"), + "grad_down_lora_b": torch.zeros(E, H, r, dtype=dtype, device="cpu"), + } + + lora_params = [] + for key in ("gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"): + param = nn.Parameter(lora_buffers[key], requires_grad=True) + param.grad = lora_grad_buffers[f"grad_{key}"] + lora_params.append(param) + + return lora_buffers, lora_grad_buffers, lora_params + + # ============================================================================= # PEFT Weight View Replacement # ============================================================================= @@ -510,15 +589,7 @@ def 
def _save_fused_expert_lora(wrappers: list, output_dir: str) -> None:
    """Save fused expert LoRA params to a safetensors file.

    Every wrapper that exposes ``_fused_expert_lora_params`` contributes its
    parameters under keys of the form ``layers.<layer_idx>.experts.<name>``.
    The parameter list is paired positionally with the names
    gate/up/down ``lora_a``/``lora_b``. Nothing is written when no wrapper
    carries fused LoRA params.

    Args:
        wrappers: KT layer wrappers to scan.
        output_dir: Directory that receives ``fused_expert_lora.safetensors``.
    """
    from safetensors.torch import save_file

    names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
    tensors = {}
    for w in wrappers:
        fused = getattr(w, "_fused_expert_lora_params", None)
        if fused is None:
            continue
        for param, name in zip(fused, names):
            key = f"layers.{w.layer_idx}.experts.{name}"
            # Clone so the saved tensor is an independent snapshot.
            tensors[key] = param.data.clone()

    if tensors:
        path = os.path.join(output_dir, "fused_expert_lora.safetensors")
        save_file(tensors, path)
        logger.info(f"[save_kt_moe] Saved {len(tensors)} fused expert LoRA tensors to {path}")


def _load_fused_expert_lora(wrappers: list, adapter_path: str) -> None:
    """Load fused expert LoRA params from a safetensors file into existing wrapper buffers.

    Keys are expected to look like ``layers.<layer_idx>.experts.<name>``.
    Malformed keys (wrong structure or a non-integer layer index) are logged
    and skipped. Tensors are copied in place into the matching entry of the
    wrapper's ``_fused_expert_lora_params`` list.

    Args:
        wrappers: KT layer wrappers whose fused LoRA params should be filled.
        adapter_path: Directory containing ``fused_expert_lora.safetensors``.

    Raises:
        RuntimeError: If a saved tensor's shape differs from the target
            parameter's shape (e.g. the adapter was saved with another rank).
    """
    path = os.path.join(adapter_path, "fused_expert_lora.safetensors")
    if not os.path.isfile(path):
        logger.warning(f"No fused_expert_lora.safetensors found at {adapter_path}")
        return

    from safetensors.torch import load_file

    saved = load_file(path)
    names = ["gate_lora_a", "gate_lora_b", "up_lora_a", "up_lora_b", "down_lora_a", "down_lora_b"]
    wrapper_map = {w.layer_idx: w for w in wrappers}
    loaded_count = 0
    skipped_count = 0

    for key, tensor in saved.items():
        parts = key.split(".")
        if len(parts) != 4 or parts[0] != "layers" or parts[2] != "experts":
            logger.warning(f"Unexpected key in fused_expert_lora.safetensors: {key}")
            continue
        try:
            layer_idx = int(parts[1])
        except ValueError:
            # A non-numeric layer index is just another malformed key; do not
            # let it abort the whole load.
            logger.warning(f"Unexpected key in fused_expert_lora.safetensors: {key}")
            continue
        name = parts[3]
        if name not in names:
            continue

        wrapper = wrapper_map.get(layer_idx)
        fused = getattr(wrapper, "_fused_expert_lora_params", None) if wrapper is not None else None
        if fused is None:
            skipped_count += 1
            continue

        target = fused[names.index(name)]
        # copy_ broadcasts its source, so a tensor saved with a different LoRA
        # rank could otherwise be loaded silently with wrong values.
        if tuple(tensor.shape) != tuple(target.shape):
            raise RuntimeError(
                f"[_load_fused_expert_lora] Shape mismatch for {key}: "
                f"saved {tuple(tensor.shape)} vs. expected {tuple(target.shape)}"
            )
        # In-place copy keeps the existing buffer storage.
        target.data.copy_(tensor)
        loaded_count += 1

    if skipped_count:
        logger.warning(
            f"[_load_fused_expert_lora] Skipped {skipped_count} tensors with no matching fused-LoRA wrapper"
        )
    logger.info(f"[_load_fused_expert_lora] Loaded {loaded_count} tensors from {path}")
["base_model", "model"]: - if hasattr(base_model, attr): - base_model = getattr(base_model, attr) - wrappers = getattr(base_model, "_kt_wrappers", []) - if wrappers: - break + wrappers = _find_kt_wrappers(model) or [] if not wrappers: logger.warning("No KT wrappers found, skipping LoRA Experts loading") return @@ -667,22 +785,19 @@ def load_kt_moe_from_adapter(model: nn.Module, adapter_path: str) -> None: Note: Per-expert PEFT LoRA is loaded by PEFT directly, not here. This function only handles lora_experts (a separate feature). """ - wrappers = getattr(model, "_kt_wrappers", []) - if not wrappers: - base_model = model - for attr in ["base_model", "model"]: - if hasattr(base_model, attr): - base_model = getattr(base_model, attr) - wrappers = getattr(base_model, "_kt_wrappers", []) - if wrappers: - break + wrappers = _find_kt_wrappers(model) or [] if not wrappers: logger.warning("No KT wrappers found, skipping KT MoE loading") return has_lora_experts = any(w.lora_experts is not None for w in wrappers) + has_fused_lora = any(getattr(w, "_fused_expert_lora_params", None) is not None for w in wrappers) if has_lora_experts: load_lora_experts_from_adapter(model, adapter_path) - else: - logger.info("No lora_experts in KT wrappers (PEFT LoRA is loaded by PEFT directly)") + + if has_fused_lora: + _load_fused_expert_lora(wrappers, adapter_path) + + if not has_lora_experts and not has_fused_lora: + logger.info("No lora_experts or fused expert LoRA in KT wrappers") diff --git a/kt-kernel/python/sft/weights.py b/kt-kernel/python/sft/weights.py index b0bbb6a2..207f8e4f 100644 --- a/kt-kernel/python/sft/weights.py +++ b/kt-kernel/python/sft/weights.py @@ -40,8 +40,28 @@ def extract_moe_weights( Returns (gate_proj, up_proj, down_proj) with shape [expert_num, out_features, in_features]. 
+ + Supports two formats: + - ModuleList of Linear experts (transformers v4 style) + - Fused Parameters (transformers v5 style): single module with + ``gate_up_proj`` [E, 2*I, H] and ``down_proj`` [E, H, I] tensors. """ + from .arch import detect_fused_experts + experts = getattr(moe_module, moe_config.experts_attr) + + # Fused format (transformers v5): a single nn.Module with gate_up_proj/down_proj tensors + if detect_fused_experts(experts): + gate_up = getattr(experts, "gate_up_proj").data + down_fused = getattr(experts, "down_proj").data + # gate_up_proj is [E, 2*I, H], split into gate [E, I, H] and up [E, I, H] + intermediate = gate_up.shape[1] // 2 + gate_proj = gate_up[:, :intermediate, :].contiguous() + up_proj = gate_up[:, intermediate:, :].contiguous() + # down_proj is already [E, H, I] + down_proj = down_fused.contiguous() + return gate_proj, up_proj, down_proj + gate_name, up_name, down_name = moe_config.weight_names gather_params: list[torch.nn.Parameter] = [] @@ -92,10 +112,27 @@ def _clear_original_expert_weights(moe_module: nn.Module, moe_config: MOEArchCon """ Clear original expert weights to free memory after KT weights are loaded. 
""" + from .arch import detect_fused_experts + experts = getattr(moe_module, moe_config.experts_attr, None) if experts is None: return + # Fused format: replace gate_up_proj/down_proj tensors with zero-storage placeholders + if detect_fused_experts(experts): + for name in ("gate_up_proj", "down_proj"): + param = getattr(experts, name, None) + if not isinstance(param, torch.nn.Parameter): + continue + original_dtype = param.dtype + tiny_storage = torch.UntypedStorage(1, device="cpu") + fake_tensor = torch.tensor([], dtype=original_dtype, device="cpu").set_( + tiny_storage, storage_offset=0, size=param.shape, + stride=[0] * len(param.shape), + ) + experts._parameters[name] = nn.Parameter(fake_tensor, requires_grad=False) + return + def _iter_weight_params(): for expert in experts: for weight_name in moe_config.weight_names: diff --git a/kt-kernel/python/sft/wrapper.py b/kt-kernel/python/sft/wrapper.py index cc060826..4b29bfd7 100644 --- a/kt-kernel/python/sft/wrapper.py +++ b/kt-kernel/python/sft/wrapper.py @@ -264,12 +264,17 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT model_container, layers = _get_model_container_and_layers(model, purpose="wrapping") logger.info(f"Total layers={len(layers)}, is_rank_0={is_rank_0}") + from .arch import detect_fused_experts as _detect_fused + for layer_idx, layer in enumerate(layers): moe_module = get_moe_module(layer, moe_config) if moe_module is None: continue - logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method})") + _layer_experts = getattr(moe_module, moe_config.experts_attr, None) + _layer_is_fused = _detect_fused(_layer_experts) + + logger.debug(f"Wrapping MoE layer {layer_idx} (method={kt_method}, fused={_layer_is_fused})") # Only rank 0 loads weights and initializes KT kernel gate_proj, up_proj, down_proj = None, None, None @@ -312,7 +317,6 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT 
num_experts_per_tok=moe_config.num_experts_per_tok, hidden_size=hidden_size, moe_intermediate_size=moe_config.intermediate_size, - gpu_experts_mask=None, num_gpu_experts=0, cpuinfer_threads=getattr(cfg, "kt_num_threads", 1), threadpool_count=threadpool_count, @@ -370,7 +374,8 @@ def wrap_moe_layers_with_kt_wrapper(model: nn.Module, kt_plugin: Any) -> list[KT layer_idx=layer_idx, lora_experts=lora_experts, ) - layer_wrapper._skip_lora = "SkipLoRA" in kt_method + layer_wrapper._fused_experts = _layer_is_fused + layer_wrapper._lora_rank = lora_rank setattr(layer, moe_config.moe_layer_attr, layer_wrapper) # Base weights have been copied into the C++ kernel's internal BufferB format.