mirror of
https://github.com/kvcache-ai/ktransformers.git
synced 2026-06-08 23:37:58 +00:00
904 lines
37 KiB
Python
904 lines
37 KiB
Python
# AMX SFT MoE Wrapper classes for CPU-based fine-tuning operations
|
|
# SPDX-License-Identifier: Apache-2.0
|
|
|
|
"""
|
|
AMX-based SFT MoE Wrapper implementation.
|
|
|
|
Supports quantization methods:
|
|
- AMXBF16_SFT: BF16 precision training
|
|
- AMXINT8_SFT: INT8 quantization training
|
|
- AMXINT4_SFT: INT4 quantization training
|
|
- AMXINT4_KGroup_SFT: INT4 K-Group quantization training (AWQ/K2); currently disabled (commented out of the method registry below)
|
|
"""
|
|
|
|
import ctypes
|
|
import torch
|
|
from typing import Dict, Tuple, Optional, List
|
|
|
|
from kt_kernel_ext.moe import MOESFTConfig
|
|
|
|
from .loader import BF16SafeTensorLoader, SafeTensorLoader
|
|
|
|
try:
|
|
from kt_kernel_ext.moe import (
|
|
AMXBF16_SFT_MOE,
|
|
AMXInt8_SFT_MOE,
|
|
AMXInt4_SFT_MOE,
|
|
# AMXInt4_1_SFT_MOE,
|
|
# AMXInt4_1KGroup_SFT_MOE,
|
|
# AMXInt4_KGroup_SFT_MOE,
|
|
# SkipLoRA variants (skip all LoRA computation in backward)
|
|
AMXBF16_SFT_MOE_SkipLoRA,
|
|
AMXInt8_SFT_MOE_SkipLoRA,
|
|
AMXInt4_SFT_MOE_SkipLoRA,
|
|
# AMXInt4_1_SFT_MOE_SkipLoRA,
|
|
# AMXInt4_1KGroup_SFT_MOE_SkipLoRA,
|
|
# AMXInt4_KGroup_SFT_MOE_SkipLoRA,
|
|
)
|
|
|
|
_HAS_AMX_SFT_SUPPORT = True
|
|
except (ImportError, AttributeError):
|
|
_HAS_AMX_SFT_SUPPORT = False
|
|
AMXBF16_SFT_MOE = None
|
|
AMXInt8_SFT_MOE = None
|
|
AMXInt4_SFT_MOE = None
|
|
# AMXInt4_1_SFT_MOE = None
|
|
# AMXInt4_1KGroup_SFT_MOE = None
|
|
# AMXInt4_KGroup_SFT_MOE = None
|
|
# SkipLoRA variants
|
|
AMXBF16_SFT_MOE_SkipLoRA = None
|
|
AMXInt8_SFT_MOE_SkipLoRA = None
|
|
AMXInt4_SFT_MOE_SkipLoRA = None
|
|
# AMXInt4_1_SFT_MOE_SkipLoRA = None
|
|
# AMXInt4_1KGroup_SFT_MOE_SkipLoRA = None
|
|
# AMXInt4_KGroup_SFT_MOE_SkipLoRA = None
|
|
|
|
from ..experts_sft import BaseSFTMoEWrapper, KExpertsSFTBuffer
|
|
|
|
|
|
# Mapping from method string to C++ SFT MOE class
|
|
_SFT_METHOD_TO_CLASS = {
    # --- Standard variants ---------------------------------------------
    # Each value is None when the guarded kt_kernel_ext import above failed.
    "AMXBF16_SFT": AMXBF16_SFT_MOE,
    "AMXINT8_SFT": AMXInt8_SFT_MOE,
    "AMXINT4_SFT": AMXInt4_SFT_MOE,
    # Disabled variants (their classes are not imported above):
    # "AMXINT4_1_SFT": AMXInt4_1_SFT_MOE,
    # "AMXINT4_KGroup_SFT": AMXInt4_KGroup_SFT_MOE,
    # "AMXINT4_1KGroup_SFT": AMXInt4_1KGroup_SFT_MOE,
    #
    # --- SkipLoRA variants ---------------------------------------------
    # Skip all LoRA computation in backward; only the base-weight
    # grad_input is computed.
    "AMXBF16_SFT_SkipLoRA": AMXBF16_SFT_MOE_SkipLoRA,
    "AMXINT8_SFT_SkipLoRA": AMXInt8_SFT_MOE_SkipLoRA,
    "AMXINT4_SFT_SkipLoRA": AMXInt4_SFT_MOE_SkipLoRA,
    # Disabled variants (their classes are not imported above):
    # "AMXINT4_1_SFT_SkipLoRA": AMXInt4_1_SFT_MOE_SkipLoRA,
    # "AMXINT4_KGroup_SFT_SkipLoRA": AMXInt4_KGroup_SFT_MOE_SkipLoRA,
    # "AMXINT4_1KGroup_SFT_SkipLoRA": AMXInt4_1KGroup_SFT_MOE_SkipLoRA,
}
|
|
|
|
|
|
class AMXSFTMoEWrapper(BaseSFTMoEWrapper):
|
|
"""
|
|
AMX-based SFT MoE wrapper implementation.
|
|
|
|
Supports BF16, INT8, and INT4 quantization methods (INT4 K-Group variants are currently disabled)
|
|
for supervised fine-tuning with LoRA adapters.
|
|
|
|
Design Note (forward_sft vs forward):
|
|
forward_sft() is implemented independently from inference forward() because:
|
|
1. Different requirements: inference optimizes for latency, SFT requires gradient correctness
|
|
2. Safety: inference optimizations (deferred experts, async execution) would break SFT gradients
|
|
3. Most reusable optimizations are already in C++ layer (via inheritance)
|
|
4. Manual copying of useful optimizations is safer and more maintainable
|
|
"""
|
|
|
|
def __init__(
    self,
    layer_idx: int,
    num_experts: int,
    num_experts_per_tok: int,
    hidden_size: int,
    moe_intermediate_size: int,
    num_gpu_experts: int,
    cpuinfer_threads: int,
    threadpool_count: int,
    weight_path: str,
    chunked_prefill_size: int,
    # SFT-specific parameters
    lora_rank: int = 16,
    lora_alpha: float = 32.0,
    max_cache_depth: int = 1,
    method: str = "AMXBF16_SFT",
    # Quantization config (for K-Group methods)
    group_size: int = 128,
    zero_point: bool = True,
):
    """
    Set up the wrapper and resolve the C++ MoE class for ``method``.

    The C++ MoE instance itself is not created here; that happens in
    load_weights().

    Args:
        layer_idx: Layer index
        num_experts: Total number of experts
        num_experts_per_tok: Number of experts per token (top-k)
        hidden_size: Hidden dimension size
        moe_intermediate_size: MoE intermediate size
        num_gpu_experts: Number of experts on GPU (usually 0 for SFT)
        cpuinfer_threads: Number of CPU inference threads
        threadpool_count: Number of NUMA subpools (TP count)
        weight_path: Path to weights
        chunked_prefill_size: Maximum prefill chunk size
        lora_rank: LoRA rank (r)
        lora_alpha: LoRA scaling factor (alpha)
        max_cache_depth: Maximum forward cache depth
        method: AMX quantization method for SFT (a key of _SFT_METHOD_TO_CLASS)
        group_size: Quantization group size (for K-Group methods)
        zero_point: Whether to use zero point quantization (for K-Group methods)

    Raises:
        RuntimeError: If the AMX SFT extension could not be imported, or the
            class registered for ``method`` is None in this build.
        ValueError: If ``method`` is not a key of _SFT_METHOD_TO_CLASS.
    """
    if not _HAS_AMX_SFT_SUPPORT:
        raise RuntimeError(
            "AMX SFT backend not available. kt_kernel_ext was not compiled with AMX SFT support.\n"
            "Please recompile with AMX SFT enabled."
        )

    # Everything the base wrapper needs is forwarded unchanged.
    base_kwargs = {
        "layer_idx": layer_idx,
        "num_experts": num_experts,
        "num_experts_per_tok": num_experts_per_tok,
        "hidden_size": hidden_size,
        "moe_intermediate_size": moe_intermediate_size,
        "num_gpu_experts": num_gpu_experts,
        "cpuinfer_threads": cpuinfer_threads,
        "threadpool_count": threadpool_count,
        "weight_path": weight_path,
        "chunked_prefill_size": chunked_prefill_size,
        "lora_rank": lora_rank,
        "lora_alpha": lora_alpha,
        "max_cache_depth": max_cache_depth,
    }
    super().__init__(**base_kwargs)

    # Remember the selected method and its quantization settings.
    self.method = method
    self.group_size = group_size
    self.zero_point = zero_point

    # Unknown name -> ValueError; known name without a compiled class -> RuntimeError.
    if method not in _SFT_METHOD_TO_CLASS:
        supported = list(_SFT_METHOD_TO_CLASS.keys())
        raise ValueError(f"Unknown SFT method: {method}. " f"Supported methods: {supported}")

    resolved_class = _SFT_METHOD_TO_CLASS[method]
    if resolved_class is None:
        raise RuntimeError(f"AMX SFT method '{method}' not available in current build.")

    # Base weights are filled in later, either by load_weights_from_tensors()
    # or by loading from weight_path.
    self.gate_proj: Optional[torch.Tensor] = None
    self.up_proj: Optional[torch.Tensor] = None
    self.down_proj: Optional[torch.Tensor] = None

    # Instantiated with a config inside load_weights().
    self._moe_class = resolved_class
|
|
|
|
def load_weights(self, physical_to_logical_map_cpu: torch.Tensor) -> None:
    """
    Create the C++ MoE instance for this layer and load its base weights.

    Supports two loading modes:
    1. From tensors: load_weights_from_tensors() stores the tensors and
       then calls this method.
    2. From files: if no base weights are set, they are read from
       weight_path via _load_base_weights_from_file()
       - AMXBF16_SFT: Uses BF16SafeTensorLoader (HuggingFace format)
       - AMXINT8_SFT/AMXINT4_SFT: Uses SafeTensorLoader (pre-quantized format)

    The call is a no-op when weights have already been loaded. After a
    successful load, the Python-side references to the base weights are
    dropped.

    Args:
        physical_to_logical_map_cpu: Mapping from physical to logical expert IDs.
            Not read by this method; kept for interface compatibility.
    """
    if self._weights_loaded:
        return

    # If base weights not set, try to load from file
    if self.gate_proj is None and not getattr(self, "_use_projs_path", False):
        self._load_base_weights_from_file()
    # Re-read the flag: the file loader sets it for pre-quantized weights.
    use_projs_path = getattr(self, "_use_projs_path", False)

    # Create MOE SFT config
    config = MOESFTConfig()
    config.expert_num = self.num_experts
    config.num_experts_per_tok = self.num_experts_per_tok
    config.hidden_size = self.hidden_size
    config.intermediate_size = self.moe_intermediate_size
    config.lora_rank = self.lora_rank
    config.lora_alpha = self.lora_alpha
    config.max_cache_depth = self.max_cache_depth
    config.max_len = self.chunked_prefill_size
    config.layer_idx = self.layer_idx

    # Set base weight pointers
    if use_projs_path:
        # Pre-quantized per-NUMA per-expert path (INT8/INT4)
        config.gate_projs = self._gate_projs_ptrs
        config.up_projs = self._up_projs_ptrs
        config.down_projs = self._down_projs_ptrs
        config.gate_scales = self._gate_scale_ptrs
        config.up_scales = self._up_scale_ptrs
        config.down_scales = self._down_scale_ptrs
        # Also provide BF16 weight pointers for backward gradient computation
        # when a BF16 copy is available (per the original author: C++ backward
        # needs BF16 base weights for the gate/up LoRA B gradients).
        if getattr(self, "_bf16_gate_proj", None) is not None:
            config.gate_proj = self._bf16_gate_proj.data_ptr()
            config.up_proj = self._bf16_up_proj.data_ptr()
            config.down_proj = self._bf16_down_proj.data_ptr()
    else:
        # Flat BF16 buffer path
        config.gate_proj = self.gate_proj.data_ptr()
        config.up_proj = self.up_proj.data_ptr()
        config.down_proj = self.down_proj.data_ptr()

    # Set LoRA weight pointers (if initialized)
    if self._lora_initialized:
        config.gate_lora_a = self.gate_lora_a.data_ptr()
        config.gate_lora_b = self.gate_lora_b.data_ptr()
        config.up_lora_a = self.up_lora_a.data_ptr()
        config.up_lora_b = self.up_lora_b.data_ptr()
        config.down_lora_a = self.down_lora_a.data_ptr()
        config.down_lora_b = self.down_lora_b.data_ptr()

    # Set thread pool
    config.pool = self.cpu_infer.backend_

    # Quantization config applies to every K-Group method. Matching on the
    # "KGroup" substring also covers the *_SkipLoRA K-Group variants, which an
    # exact-name comparison against the two non-SkipLoRA names would miss.
    if "KGroup" in self.method:
        config.quant_config.group_size = self.group_size
        config.quant_config.zero_point = self.zero_point

    # Create MoE instance
    self.moe = self._moe_class(config)

    # Load weights, then warm up; each task is awaited before the next step.
    self.cpu_infer.submit(self.moe.load_weights_task())
    self.cpu_infer.sync()
    self.cpu_infer.submit(self.moe.warm_up_task())
    self.cpu_infer.sync()

    # Release Python-side base weight tensors. Per the original author, C++
    # has already copied/transformed them into its internal format, so holding
    # them here only wastes memory (~1 GB/layer).
    self.gate_proj = None
    self.up_proj = None
    self.down_proj = None

    if getattr(self, "_bf16_gate_proj", None) is not None:
        self._bf16_gate_proj = None
        self._bf16_up_proj = None
        self._bf16_down_proj = None

    # Release the pre-quantized per-NUMA arrays and their pointer tables.
    if use_projs_path:
        for attr_name in (
            "_gate_weights_per_numa",
            "_up_weights_per_numa",
            "_down_weights_per_numa",
            "_gate_scales_per_numa",
            "_up_scales_per_numa",
            "_down_scales_per_numa",
            "_gate_projs_ptrs",
            "_up_projs_ptrs",
            "_down_projs_ptrs",
            "_gate_scale_ptrs",
            "_up_scale_ptrs",
            "_down_scale_ptrs",
        ):
            setattr(self, attr_name, None)

    self._weights_loaded = True
|
|
|
|
def load_weights_from_tensors(
    self,
    gate_proj: torch.Tensor,
    up_proj: torch.Tensor,
    down_proj: torch.Tensor,
    physical_to_logical_map_cpu: torch.Tensor,
) -> None:
    """
    Load weights from BF16/FP16 tensors.

    This is the recommended way to load weights for SFT, as it supports
    online quantization from full-precision weights.

    If weights were already loaded this is a no-op: load_weights() would
    return immediately anyway, so storing the tensors on the wrapper would
    only keep them alive for no benefit.

    Args:
        gate_proj: Gate projection weights [num_experts, intermediate_size, hidden_size]
        up_proj: Up projection weights [num_experts, intermediate_size, hidden_size]
        down_proj: Down projection weights [num_experts, hidden_size, intermediate_size]
        physical_to_logical_map_cpu: Mapping from physical to logical expert IDs
    """
    if self._weights_loaded:
        return

    # Hold contiguous copies on the instance so the raw pointers handed to
    # C++ in load_weights() stay valid while the load runs.
    self.gate_proj = gate_proj.contiguous()
    self.up_proj = up_proj.contiguous()
    self.down_proj = down_proj.contiguous()

    # load_weights() drops these references again once loading has finished.
    self.load_weights(physical_to_logical_map_cpu)
|
|
|
|
|
|
def _load_base_weights_from_file(self) -> None:
|
|
"""
|
|
Load base MoE weights from file based on the SFT method.
|
|
|
|
Loading strategy:
|
|
- AMXBF16_SFT: Use BF16SafeTensorLoader (HuggingFace format, no scales)
|
|
- AMXINT8_SFT/AMXINT4_SFT: Use SafeTensorLoader (pre-quantized format with scales)
|
|
"""
|
|
if not hasattr(self, "weight_path") or self.weight_path is None:
|
|
raise RuntimeError(
|
|
"weight_path not set. Cannot load weights from file. "
|
|
"Either set weight_path or call load_weights_from_tensors() instead."
|
|
)
|
|
|
|
print(
|
|
f"[AMXSFTMoEWrapper] Loading base weights for layer {self.layer_idx} "
|
|
f"from {self.weight_path} using method {self.method}"
|
|
)
|
|
|
|
# Determine loader and base key format based on method
|
|
if "BF16" in self.method:
|
|
# BF16 mode: Load from HuggingFace model path
|
|
loader = BF16SafeTensorLoader(self.weight_path)
|
|
base_key = f"model.layers.{self.layer_idx}"
|
|
else:
|
|
# INT8/INT4 mode: Load from pre-quantized path
|
|
# Note: SafeTensorLoader expects GGUF-style naming (blk.X)
|
|
loader = SafeTensorLoader(self.weight_path)
|
|
base_key = f"blk.{self.layer_idx}"
|
|
|
|
# Load expert weights
|
|
experts_data = loader.load_experts(base_key, device="cpu")
|
|
|
|
# Extract weights (list of tensors per expert -> stacked tensor)
|
|
gate_weights: List[torch.Tensor] = experts_data["gate"]
|
|
up_weights: List[torch.Tensor] = experts_data["up"]
|
|
down_weights: List[torch.Tensor] = experts_data["down"]
|
|
|
|
# Stack expert weights: [num_experts, ...]
|
|
# For BF16: weights are already tensors
|
|
# For SafeTensorLoader: weights might be numpy arrays in nested lists
|
|
if "BF16" in self.method:
|
|
# BF16SafeTensorLoader returns list of tensors
|
|
self.gate_proj = torch.stack(gate_weights, dim=0).contiguous()
|
|
self.up_proj = torch.stack(up_weights, dim=0).contiguous()
|
|
self.down_proj = torch.stack(down_weights, dim=0).contiguous()
|
|
else:
|
|
# SafeTensorLoader returns nested lists [numa_id][expert_id] -> numpy array
|
|
# Keep per-NUMA per-expert arrays for gate_projs/gate_scales path
|
|
import numpy as np
|
|
|
|
num_numa = len(gate_weights)
|
|
|
|
# Store raw per-NUMA per-expert numpy arrays (keep references alive)
|
|
self._gate_weights_per_numa = gate_weights # [numa_id][expert_id] -> np array
|
|
self._up_weights_per_numa = up_weights
|
|
self._down_weights_per_numa = down_weights
|
|
self._gate_scales_per_numa = experts_data["gate_scale"]
|
|
self._up_scales_per_numa = experts_data["up_scale"]
|
|
self._down_scales_per_numa = experts_data["down_scale"]
|
|
|
|
# Build pointer arrays: [[ptr_expert_0, ptr_expert_1, ...], ...] per NUMA
|
|
def _make_ptrs(arrays_per_numa):
|
|
return [
|
|
[
|
|
ctypes.addressof(ctypes.cast(et.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents)
|
|
for et in numa_array
|
|
]
|
|
for numa_array in arrays_per_numa
|
|
]
|
|
|
|
self._gate_projs_ptrs = _make_ptrs(gate_weights)
|
|
self._up_projs_ptrs = _make_ptrs(up_weights)
|
|
self._down_projs_ptrs = _make_ptrs(down_weights)
|
|
self._gate_scale_ptrs = _make_ptrs(experts_data["gate_scale"])
|
|
self._up_scale_ptrs = _make_ptrs(experts_data["up_scale"])
|
|
self._down_scale_ptrs = _make_ptrs(experts_data["down_scale"])
|
|
|
|
# Set gate_proj to None so load_weights() uses gate_projs path
|
|
self.gate_proj = None
|
|
self.up_proj = None
|
|
self.down_proj = None
|
|
self._use_projs_path = True
|
|
|
|
# Close loader handles
|
|
loader.close_all_handles()
|
|
|
|
if getattr(self, "_use_projs_path", False):
|
|
num_numa = len(self._gate_weights_per_numa)
|
|
num_experts = len(self._gate_weights_per_numa[0])
|
|
print(
|
|
f"[AMXSFTMoEWrapper] Loaded pre-quantized weights: "
|
|
f"{num_numa} NUMA nodes, {num_experts} experts per NUMA"
|
|
)
|
|
else:
|
|
print(
|
|
f"[AMXSFTMoEWrapper] Loaded weights: gate_proj={self.gate_proj.shape}, "
|
|
f"up_proj={self.up_proj.shape}, down_proj={self.down_proj.shape}"
|
|
)
|
|
|
|
def init_lora_weights(
    self,
    gate_lora_a: torch.Tensor,
    gate_lora_b: torch.Tensor,
    up_lora_a: torch.Tensor,
    up_lora_b: torch.Tensor,
    down_lora_a: torch.Tensor,
    down_lora_b: torch.Tensor,
) -> None:
    """
    Register the six LoRA matrices and allocate their gradient buffers.

    LoRA output formula:
        lora_output = (input @ A.T @ B.T) * (lora_alpha / lora_rank)
        output = base_output + lora_output

    Each matrix is shape-checked, stored as a contiguous tensor under its own
    name, and paired with a zero-filled bf16 CPU buffer ``grad_<name>`` of
    the same shape. If base weights are already loaded and a MoE instance
    exists, the new pointers are pushed to C++ via update_lora_weights().

    Args:
        gate_lora_a: Gate LoRA A matrix [num_experts, lora_rank, hidden_size]
        gate_lora_b: Gate LoRA B matrix [num_experts, intermediate_size, lora_rank]
        up_lora_a: Up LoRA A matrix [num_experts, lora_rank, hidden_size]
        up_lora_b: Up LoRA B matrix [num_experts, intermediate_size, lora_rank]
        down_lora_a: Down LoRA A matrix [num_experts, lora_rank, intermediate_size]
        down_lora_b: Down LoRA B matrix [num_experts, hidden_size, lora_rank]

    Raises:
        ValueError: If any matrix does not have the expected shape.
    """
    n_exp = self.num_experts
    rank = self.lora_rank
    hidden = self.hidden_size
    inter = self.moe_intermediate_size

    # (name, tensor, required shape), in the order they are validated and stored.
    specs = [
        ("gate_lora_a", gate_lora_a, (n_exp, rank, hidden)),
        ("gate_lora_b", gate_lora_b, (n_exp, inter, rank)),
        ("up_lora_a", up_lora_a, (n_exp, rank, hidden)),
        ("up_lora_b", up_lora_b, (n_exp, inter, rank)),
        ("down_lora_a", down_lora_a, (n_exp, rank, inter)),
        ("down_lora_b", down_lora_b, (n_exp, hidden, rank)),
    ]

    # Validate everything first so a bad argument leaves the wrapper untouched.
    for name, tensor, expected in specs:
        if tensor.shape != expected:
            raise ValueError(f"{name} shape mismatch: expected {expected}, got {tuple(tensor.shape)}")

    # Store LoRA weights (contiguous for C++ access)
    for name, tensor, _ in specs:
        setattr(self, name, tensor.contiguous())

    # One zeroed bf16 gradient buffer per LoRA matrix.
    for name, _, shape in specs:
        setattr(self, f"grad_{name}", torch.zeros(shape, dtype=torch.bfloat16, device="cpu"))

    self._lora_initialized = True

    # If weights already loaded, update LoRA pointers in C++
    if self._weights_loaded and self.moe is not None:
        self.update_lora_weights()
|
|
|
|
def forward_sft(
    self,
    hidden_states: torch.Tensor,
    expert_ids: torch.Tensor,
    weights: torch.Tensor,
    save_for_backward: bool = True,
    output_device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    SFT forward pass with optional gradient caching (blocking).

    Inputs are copied into the staging buffers obtained from
    KExpertsSFTBuffer, the C++ forward task is submitted and awaited, and the
    result is returned.

    The forward-cache capacity is checked *before* any work is submitted, so
    a full cache raises without running the forward or changing the cache
    depth.

    Args:
        hidden_states: Input hidden states [qlen, hidden_size] (any device, will be converted to bf16)
        expert_ids: Expert IDs [qlen, num_experts_per_tok] (any device, will be converted to int64)
        weights: Expert weights [qlen, num_experts_per_tok] (any device, will be converted to float32)
        save_for_backward: Whether to save activations for backward pass
        output_device: Target device for the output. None or a CPU device
            returns an independent CPU copy; any other device receives a
            non-blocking transfer from the staging buffer.

    Returns:
        Output hidden states [qlen, hidden_size]

    Raises:
        RuntimeError: If weights or LoRA weights are not ready, or if
            save_for_backward is set and the forward cache is already full.
        ValueError: If qlen exceeds chunked_prefill_size or the routing
            tensors do not have shape (qlen, num_experts_per_tok).
    """
    if not self._weights_loaded:
        raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.")

    if not self._lora_initialized:
        raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")

    qlen = hidden_states.shape[0]
    if qlen > self.chunked_prefill_size:
        raise ValueError(
            f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). "
            "Increase chunked_prefill_size or reduce qlen to avoid buffer overrun."
        )
    if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok:
        raise ValueError(
            f"expert_ids shape {tuple(expert_ids.shape)} must be " f"({qlen}, {self.num_experts_per_tok})."
        )
    if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok:
        raise ValueError(f"weights shape {tuple(weights.shape)} must be " f"({qlen}, {self.num_experts_per_tok}).")

    # Refuse up front when this call would exceed the cache limit; checking
    # only after the task has run would leave the depth counter above the max.
    if save_for_backward and self._cache_depth + 1 > self.max_cache_depth:
        raise RuntimeError(
            f"Forward cache full (depth={self._cache_depth + 1}, max={self.max_cache_depth}). "
            "Call backward() to release cache entries."
        )

    # Get or create buffer (always bf16 for computation)
    buffer = KExpertsSFTBuffer.get_buffer(
        qlen=qlen,
        hidden_size=self.hidden_size,
        moe_intermediate_size=self.moe_intermediate_size,
        num_experts=self.num_experts,
        num_experts_per_tok=self.num_experts_per_tok,
        lora_rank=self.lora_rank,
        dtype=torch.bfloat16,
    )

    # Stage inputs in the CPU buffers; works for both CPU and GPU tensors.
    input_device = hidden_states.device
    buffer.input_cpu.copy_(hidden_states.to(torch.bfloat16), non_blocking=True)
    buffer.expert_ids_cpu.copy_(expert_ids.to(torch.int64), non_blocking=True)
    buffer.weights_cpu.copy_(weights.to(torch.float32), non_blocking=True)
    buffer.bsz_tensor[0] = qlen

    # The copies above may be asynchronous for GPU inputs; wait for them
    # before the CPU kernel reads the buffers.
    if input_device.type == "cuda":
        torch.cuda.synchronize(input_device)

    # Submit forward task and wait for it.
    self.cpu_infer.submit(
        self.moe.forward_sft_task(
            buffer.bsz_tensor.data_ptr(),
            self.num_experts_per_tok,
            buffer.expert_ids_cpu.data_ptr(),
            buffer.weights_cpu.data_ptr(),
            buffer.input_cpu.data_ptr(),
            buffer.output_cpu.data_ptr(),
            save_for_backward,
        )
    )
    self.cpu_infer.sync()

    # Track cache depth (capacity was verified above).
    if save_for_backward:
        self._cache_depth += 1

    # Tensor.to() returns the same tensor when it is already on the requested
    # device, so a CPU target must be cloned to avoid handing out the staging
    # buffer itself.
    if output_device is None or torch.device(output_device).type == "cpu":
        return buffer.output_cpu.clone()
    return buffer.output_cpu.to(device=output_device, non_blocking=True)
|
|
|
|
def backward(
    self,
    grad_output: torch.Tensor,
    lora_params: Optional[Dict[str, torch.nn.Parameter]] = None,
    output_device: Optional[torch.device] = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
    """
    Backward pass computing gradients (blocking).

    Must be called after forward_sft(save_for_backward=True). Consumes one
    forward-cache entry.

    Args:
        grad_output: Gradient from upstream [qlen, hidden_size] (any device, will be converted to bf16)
        lora_params: Optional dict of LoRA parameters (kept for compatibility;
            not read by this method).
        output_device: Target device for grad_input and grad_weights. None or
            a CPU device returns independent CPU copies; any other device
            receives a non-blocking transfer from the staging buffers.

    Returns:
        grad_input: Input gradient [qlen, hidden_size]
        grad_loras: LoRA gradients dict with keys grad_gate_lora_a,
            grad_gate_lora_b, grad_up_lora_a, grad_up_lora_b,
            grad_down_lora_a, grad_down_lora_b. The values are the wrapper's
            own gradient buffers (not copies); they are zeroed at the start
            of the next backward() call.
        grad_weights: Routing weights gradient [qlen, num_experts_per_tok]

    Raises:
        RuntimeError: If there is no cached forward pass to differentiate.
    """
    if self._cache_depth <= 0:
        raise RuntimeError("No forward cache available. Call forward_sft(save_for_backward=True) first.")

    qlen = grad_output.shape[0]

    # Get buffer (should exist from forward pass, always bf16)
    buffer = KExpertsSFTBuffer.get_buffer(
        qlen=qlen,
        hidden_size=self.hidden_size,
        moe_intermediate_size=self.moe_intermediate_size,
        num_experts=self.num_experts,
        num_experts_per_tok=self.num_experts_per_tok,
        lora_rank=self.lora_rank,
        dtype=torch.bfloat16,
    )

    # Stage the upstream gradient in the CPU buffer (CPU or GPU source).
    input_device = grad_output.device
    buffer.grad_output_cpu.copy_(grad_output.to(torch.bfloat16), non_blocking=True)

    # Zero out gradient buffers
    buffer.grad_input_cpu.zero_()
    buffer.grad_weights.zero_()

    # Zero out LoRA gradient buffers (C++ backward accumulates into these)
    if self.grad_gate_lora_a is not None:
        self.grad_gate_lora_a.zero_()
        self.grad_gate_lora_b.zero_()
        self.grad_up_lora_a.zero_()
        self.grad_up_lora_b.zero_()
        self.grad_down_lora_a.zero_()
        self.grad_down_lora_b.zero_()

    # The staging copy may be asynchronous for GPU inputs; wait for it before
    # the CPU kernel reads the buffer.
    if input_device.type == "cuda":
        torch.cuda.synchronize(input_device)

    # Submit backward task and wait for it.
    self.cpu_infer.submit(
        self.moe.backward_task(
            buffer.grad_output_cpu.data_ptr(),
            buffer.grad_input_cpu.data_ptr(),
            self.grad_gate_lora_a.data_ptr(),
            self.grad_gate_lora_b.data_ptr(),
            self.grad_up_lora_a.data_ptr(),
            self.grad_up_lora_b.data_ptr(),
            self.grad_down_lora_a.data_ptr(),
            self.grad_down_lora_b.data_ptr(),
            buffer.grad_weights.data_ptr(),
        )
    )
    self.cpu_infer.sync()

    # One cached forward pass has been consumed.
    self._cache_depth -= 1

    # Tensor.to() returns the same tensor when it is already on the requested
    # device, so a CPU target must be cloned to avoid handing out the staging
    # buffers themselves.
    if output_device is None or torch.device(output_device).type == "cpu":
        grad_input = buffer.grad_input_cpu.clone()
        grad_weights = buffer.grad_weights.clone()
    else:
        grad_input = buffer.grad_input_cpu.to(device=output_device, non_blocking=True)
        grad_weights = buffer.grad_weights.to(device=output_device, non_blocking=True)

    grad_loras = {
        "grad_gate_lora_a": self.grad_gate_lora_a,
        "grad_gate_lora_b": self.grad_gate_lora_b,
        "grad_up_lora_a": self.grad_up_lora_a,
        "grad_up_lora_b": self.grad_up_lora_b,
        "grad_down_lora_a": self.grad_down_lora_a,
        "grad_down_lora_b": self.grad_down_lora_b,
    }

    return grad_input, grad_loras, grad_weights
|
|
|
|
def update_lora_weights(self) -> None:
    """
    Push the current LoRA weight pointers to the C++ backend and wait.

    Call this after an external optimizer has updated the LoRA weights.
    Per the original author, this is needed because TP mode partitions
    weights internally.

    Typical usage:
        # 1. Forward + backward
        output = wrapper.forward_sft(input, expert_ids, weights)
        grad_input, grad_loras, grad_weights = wrapper.backward(grad_output)

        # 2. Update LoRA weights with optimizer
        optimizer.step()

        # 3. Sync to C++
        wrapper.update_lora_weights()

    Raises:
        RuntimeError: If base weights or LoRA weights have not been set up.
    """
    if not self._weights_loaded:
        raise RuntimeError("Weights not loaded. Call load_weights() first.")

    if not self._lora_initialized:
        raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")

    # Argument order expected by update_lora_weights_task.
    lora_names = (
        "gate_lora_a",
        "gate_lora_b",
        "up_lora_a",
        "up_lora_b",
        "down_lora_a",
        "down_lora_b",
    )
    pointers = [getattr(self, lora_name).data_ptr() for lora_name in lora_names]

    task = self.moe.update_lora_weights_task(*pointers)
    self.cpu_infer.submit(task)
    self.cpu_infer.sync()
|
|
|
|
def submit_forward_sft(
    self,
    hidden_states: torch.Tensor,
    expert_ids: torch.Tensor,
    weights: torch.Tensor,
    save_for_backward: bool = True,
) -> None:
    """
    Submit SFT forward pass asynchronously (non-blocking).

    This method submits the CPU MoE computation without waiting for
    completion, so the caller can do other work in the meantime.
    Must be followed by sync_forward_sft() to retrieve results.

    The forward-cache capacity is checked here, before the task is
    submitted, so a full cache raises without running the forward.

    Args:
        hidden_states: Input hidden states [qlen, hidden_size] (any device, will be converted to bf16)
        expert_ids: Expert IDs [qlen, num_experts_per_tok] (any device, will be converted to int64)
        weights: Expert weights [qlen, num_experts_per_tok] (any device, will be converted to float32)
        save_for_backward: Whether to save activations for backward pass

    Raises:
        RuntimeError: If weights or LoRA weights are not ready, or if
            save_for_backward is set and the forward cache is already full.
        ValueError: If qlen exceeds chunked_prefill_size or the routing
            tensors do not have shape (qlen, num_experts_per_tok).
    """
    if not self._weights_loaded:
        raise RuntimeError("Weights not loaded. Call load_weights() or load_weights_from_tensors() first.")

    if not self._lora_initialized:
        raise RuntimeError("LoRA weights not initialized. Call init_lora_weights() first.")

    qlen = hidden_states.shape[0]
    if qlen > self.chunked_prefill_size:
        raise ValueError(
            f"qlen ({qlen}) exceeds chunked_prefill_size ({self.chunked_prefill_size}). "
            "Increase chunked_prefill_size or reduce qlen to avoid buffer overrun."
        )
    if expert_ids.shape[0] != qlen or expert_ids.shape[1] != self.num_experts_per_tok:
        raise ValueError(
            f"expert_ids shape {tuple(expert_ids.shape)} must be " f"({qlen}, {self.num_experts_per_tok})."
        )
    if weights.shape[0] != qlen or weights.shape[1] != self.num_experts_per_tok:
        raise ValueError(f"weights shape {tuple(weights.shape)} must be " f"({qlen}, {self.num_experts_per_tok}).")

    # sync_forward_sft() only detects an over-full cache after the task has
    # already run; reject the request here while nothing has been done yet.
    if save_for_backward and self._cache_depth + 1 > self.max_cache_depth:
        raise RuntimeError(
            f"Forward cache full (depth={self._cache_depth + 1}, max={self.max_cache_depth}). "
            "Call backward() to release cache entries."
        )

    # Get or create buffer (always bf16)
    buffer = KExpertsSFTBuffer.get_buffer(
        qlen=qlen,
        hidden_size=self.hidden_size,
        moe_intermediate_size=self.moe_intermediate_size,
        num_experts=self.num_experts,
        num_experts_per_tok=self.num_experts_per_tok,
        lora_rank=self.lora_rank,
        dtype=torch.bfloat16,
    )

    # Stage inputs in the CPU buffers; works for both CPU and GPU tensors.
    input_device = hidden_states.device
    buffer.input_cpu.copy_(hidden_states.to(torch.bfloat16), non_blocking=True)
    buffer.expert_ids_cpu.copy_(expert_ids.to(torch.int64), non_blocking=True)
    buffer.weights_cpu.copy_(weights.to(torch.float32), non_blocking=True)
    buffer.bsz_tensor[0] = qlen

    # The copies above may be asynchronous for GPU inputs; wait for them
    # before the CPU kernel reads the buffers.
    if input_device.type == "cuda":
        torch.cuda.synchronize(input_device)

    # Remember what sync_forward_sft() needs to finish the call.
    self._pending_buffer = buffer
    self._pending_save_for_backward = save_for_backward
    self._pending_qlen = qlen

    # Submit forward task (non-blocking)
    self.cpu_infer.submit(
        self.moe.forward_sft_task(
            buffer.bsz_tensor.data_ptr(),
            self.num_experts_per_tok,
            buffer.expert_ids_cpu.data_ptr(),
            buffer.weights_cpu.data_ptr(),
            buffer.input_cpu.data_ptr(),
            buffer.output_cpu.data_ptr(),
            save_for_backward,
        )
    )
|
|
|
|
def sync_forward_sft(self, output_device: Optional[torch.device] = None) -> torch.Tensor:
    """
    Synchronize and retrieve SFT forward results.

    Must be called after submit_forward_sft(). The pending-submission state
    is cleared as soon as the task has completed, including when the cache
    depth check below raises.

    Args:
        output_device: Target device for the output. None or a CPU device
            returns an independent CPU copy; any other device receives a
            non-blocking transfer from the staging buffer.

    Returns:
        Output hidden states [qlen, hidden_size]

    Raises:
        RuntimeError: If there is no pending submission, or if recording the
            saved forward pass pushes the cache depth past max_cache_depth.
    """
    buffer = getattr(self, "_pending_buffer", None)
    if buffer is None:
        raise RuntimeError("No pending forward. Call submit_forward_sft() first.")

    # Wait for completion
    self.cpu_infer.sync()

    save_for_backward = self._pending_save_for_backward

    # The submission is finished from here on, so clear the pending state
    # before anything below can raise; otherwise a failed call would leave a
    # stale pending buffer behind and a retry would count the same forward twice.
    self._pending_buffer = None
    self._pending_save_for_backward = None
    self._pending_qlen = None

    # Track cache depth
    if save_for_backward:
        self._cache_depth += 1
        if self._cache_depth > self.max_cache_depth:
            raise RuntimeError(
                f"Forward cache full (depth={self._cache_depth}, max={self.max_cache_depth}). "
                "Call backward() to release cache entries."
            )

    # Tensor.to() returns the same tensor when it is already on the requested
    # device, so a CPU target must be cloned to avoid handing out the staging
    # buffer itself.
    if output_device is None or torch.device(output_device).type == "cpu":
        return buffer.output_cpu.clone()
    return buffer.output_cpu.to(device=output_device, non_blocking=True)
|