Files
ktransformers/kt-kernel/python/sft/__init__.py
mrhaoxx 07fd9328fa refactor(sft): move SFT logic into kt_kernel.sft submodule
- Create python/sft/ with 11 modules: base, amx, arch, autograd, layer,
  lora, weights, wrapper, dist_utils, config, __init__
- Move BaseSFTMoEWrapper + buffer management into sft/base.py (template
  method pattern: subclass provides _make_forward/backward_task)
- Move AMXSFTMoEWrapper into sft/amx.py (thinner, no buffer logic)
- Move from accelerate's kt_moe.py: KTMoEFunction, KTMoELayerWrapper,
  MOEArchConfig, PEFT LoRA adaptation, weight extraction, wrapping
- Add KTConfig dataclass (DeepSpeed pattern: opaque config passthrough)
- Add _get_kt_config() with old→new field name compat conversion
- Rename forward_sft→forward, submit_forward_sft→submit_forward,
  sync_forward_sft→sync_forward (Python only, C++ binding names unchanged)
- Delete dump utilities from sft_moe.hpp (-526) and moe-sft-tp.hpp (-78)
- Delete experts_sft.py and utils/amx_sft.py (moved to sft/)
- Remove SFT stubs from BaseMoEWrapper (experts_base.py)
- Lazy SFT import in __init__.py and experts.py (inference isolation)
- Delete all lifecycle/debug logging (~500 lines)

Verified: Qwen3-235B training on 4 GPUs with AMXBF16; loss converges over 3 steps.
2026-04-08 23:07:41 +08:00

84 lines
2.3 KiB
Python

# SFT (Supervised Fine-Tuning) submodule for kt-kernel
# SPDX-License-Identifier: Apache-2.0
"""
SFT training support for KT-Kernel MoE.

This submodule adds training capabilities (forward/backward, LoRA, autograd,
distributed) on top of the inference-only kt_kernel base package.

Additional dependencies beyond the base kt_kernel package: torch.nn and
torch.distributed (both part of PyTorch), plus peft (optional).
"""
from .config import KTConfig
from .base import BaseSFTMoEWrapper, KExpertsSFTBuffer
from .amx import AMXSFTMoEWrapper
from .arch import (
MOEArchConfig, get_moe_arch_config, get_moe_module, move_non_experts_to_gpu, get_expert_device,
KTAMXError, KTAMXNotAvailableError, KTAMXModelNotSupportedError, KTAMXConfigError,
)
from .autograd import KTMoEFunction
from .layer import KTMoELayerWrapper
from .weights import (
extract_moe_weights,
load_experts_from_checkpoint_files,
load_experts_from_kt_weight_path,
INT8ExpertWeights,
)
from .lora import (
kt_adapt_peft_lora,
get_kt_lora_params,
update_kt_lora_pointers,
sync_kt_lora_gradients,
save_lora_experts_to_adapter,
save_kt_moe_to_adapter,
load_lora_experts_from_adapter,
load_kt_moe_from_adapter,
LoRAExpertMLP,
LoRAExperts,
)
from .wrapper import (
wrap_moe_layers_with_kt_wrapper,
build_kt_device_map,
build_kt_device_map_simplified,
get_kt_loading_kwargs,
load_kt_model,
)
# Public API of the sft submodule. Every name below is re-exported from the
# relative imports in this file; entries are grouped by the submodule they
# come from, in the same order as those imports.
__all__ = [
    # .config
    "KTConfig",
    # .base / .amx
    "BaseSFTMoEWrapper", "KExpertsSFTBuffer", "AMXSFTMoEWrapper",
    # .arch: architecture config helpers
    "MOEArchConfig", "get_moe_arch_config", "get_moe_module",
    "move_non_experts_to_gpu", "get_expert_device",
    # .arch: error types
    "KTAMXError", "KTAMXNotAvailableError",
    "KTAMXModelNotSupportedError", "KTAMXConfigError",
    # .autograd / .layer
    "KTMoEFunction", "KTMoELayerWrapper",
    # .weights
    "extract_moe_weights", "load_experts_from_checkpoint_files",
    "load_experts_from_kt_weight_path", "INT8ExpertWeights",
    # .lora
    "kt_adapt_peft_lora", "get_kt_lora_params", "update_kt_lora_pointers",
    "sync_kt_lora_gradients", "save_lora_experts_to_adapter",
    "save_kt_moe_to_adapter", "load_lora_experts_from_adapter",
    "load_kt_moe_from_adapter", "LoRAExpertMLP", "LoRAExperts",
    # .wrapper
    "wrap_moe_layers_with_kt_wrapper", "build_kt_device_map",
    "build_kt_device_map_simplified", "get_kt_loading_kwargs", "load_kt_model",
]