Add LoRA support for SGLang with KT

This commit is contained in:
JimmyPeilinLi
2025-12-24 12:13:16 +08:00
parent 48bcfdb039
commit 05ce752126
6 changed files with 397 additions and 3 deletions

162
convert_lora.py Normal file
View File

@@ -0,0 +1,162 @@
#!/usr/bin/env python3
"""
Convert PEFT format LoRA adapter to SGLang format.
This script copies an entire LoRA adapter directory and converts the safetensors
weight files from PEFT format to SGLang format by removing unnecessary prefixes.
"""
import argparse
import shutil
from pathlib import Path
from safetensors.torch import load_file, save_file
def convert_lora_adapter(adapter_path: str, output_path: str = None, verbose: bool = True):
    """
    Convert PEFT format LoRA adapter to SGLang format.

    Copies every file of the adapter directory into the output directory.
    ``.safetensors`` weight files are rewritten with the PEFT-specific key
    prefixes (``base_model.model.`` and ``.orig_module``) stripped; all
    other files (configs, tokenizer files, ...) are copied verbatim.

    Args:
        adapter_path: Path to the input PEFT adapter directory.
        output_path: Path to the output SGLang adapter directory (optional).
            If None, defaults to {adapter_path}_converted.
        verbose: Whether to print detailed conversion information.

    Returns:
        The output directory as a ``Path``. (Previously the function
        returned None; existing callers that ignore the result are
        unaffected.)

    Raises:
        FileNotFoundError: If ``adapter_path`` does not exist.
        ValueError: If ``adapter_path`` exists but is not a directory.
    """
    adapter_path = Path(adapter_path)
    # Set default output path if not provided.
    if output_path is None:
        output_path = Path(f"{adapter_path}_converted")
    else:
        output_path = Path(output_path)

    # Validate the input before creating anything on disk.
    if not adapter_path.exists():
        raise FileNotFoundError(f"Adapter path not found: {adapter_path}")
    if not adapter_path.is_dir():
        raise ValueError(f"Adapter path must be a directory: {adapter_path}")

    # Create output directory.
    output_path.mkdir(parents=True, exist_ok=True)

    if verbose:
        print("=" * 80)
        print("Converting PEFT LoRA Adapter to SGLang Format")
        print("=" * 80)
        print(f"\nInput: {adapter_path}")
        print(f"Output: {output_path}\n")

    # Process all files in the adapter directory.
    converted_files = []
    copied_files = []
    for file_path in adapter_path.iterdir():
        if not file_path.is_file():
            # Nested directories are intentionally skipped; adapters are flat.
            continue
        output_file = output_path / file_path.name

        if file_path.suffix == '.safetensors':
            # Convert safetensors weight files.
            if verbose:
                print(f"Converting: {file_path.name}")
            # Load PEFT weights.
            state_dict = load_file(str(file_path))
            # Convert keys by removing PEFT wrapper prefixes so they match
            # SGLang's expected module paths.
            new_state_dict = {}
            for key, value in state_dict.items():
                new_key = key.replace("base_model.model.", "")
                new_key = new_key.replace(".orig_module", "")
                new_state_dict[new_key] = value
                if verbose and key != new_key:
                    print(f"  {key}")
                    print(f"    -> {new_key}")
            # Save converted weights.
            save_file(new_state_dict, str(output_file))
            converted_files.append(file_path.name)
            if verbose:
                print(f"  Saved to: {output_file}\n")
        else:
            # Copy other files as-is (copy2 preserves metadata and accepts
            # Path objects directly).
            if verbose:
                print(f"Copying: {file_path.name}")
            shutil.copy2(file_path, output_file)
            copied_files.append(file_path.name)

    # Print summary.
    if verbose:
        print("=" * 80)
        print("Conversion Summary")
        print("=" * 80)
        print(f"\nConverted files ({len(converted_files)}):")
        for f in converted_files:
            print(f"  - {f}")
        print(f"\nCopied files ({len(copied_files)}):")
        for f in copied_files:
            print(f"  - {f}")
        print(f"\nTotal: {len(converted_files) + len(copied_files)} files processed")
        print(f"Output directory: {output_path}\n")

    return output_path
def main():
    """Command-line entry point: parse arguments and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Convert PEFT format LoRA adapter to SGLang format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Use default output path (adds _converted suffix)
  python convert_lora.py /path/to/adapter

  # Specify custom output path
  python convert_lora.py /path/to/adapter /path/to/output

  # Quiet mode (minimal output)
  python convert_lora.py /path/to/adapter --quiet
""",
    )
    parser.add_argument(
        "adapter_path",
        type=str,
        help="Path to the input PEFT adapter directory",
    )
    parser.add_argument(
        "output_path",
        type=str,
        nargs="?",
        default=None,
        help="Path to the output SGLang adapter directory (default: {adapter_path}_converted)",
    )
    parser.add_argument(
        "-q", "--quiet",
        action="store_true",
        help="Quiet mode - minimal output",
    )
    args = parser.parse_args()

    try:
        convert_lora_adapter(
            adapter_path=args.adapter_path,
            output_path=args.output_path,
            verbose=not args.quiet,
        )
    except Exception as e:
        # Report failures on stderr (not stdout) and exit with a non-zero
        # status via sys.exit; the bare `exit()` builtin comes from the
        # site module and should not be used in scripts.
        import sys
        import traceback
        print(f"\nError: {e}", file=sys.stderr)
        traceback.print_exc()
        sys.exit(1)


if __name__ == "__main__":
    main()

View File

@@ -14,6 +14,7 @@ from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.vocab_parallel_embedding import (
@@ -362,6 +363,67 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
return B
class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA support for ReplicatedLinear layers.

    ReplicatedLinear is used in DeepSeek-V2 for layers like kv_a_proj_with_mqa
    where weights are replicated across all TP ranks (not sharded).
    Because nothing is sharded, the slice_* hooks are identity functions and
    the output offset always covers the full output dimension.
    """

    def __init__(
        self,
        base_layer: ReplicatedLinear,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_backend)
        # For replicated linear, output is not sharded, so we use a simple offset
        # spanning the whole output range [0, output_size) on every rank.
        self.output_offset = torch.tensor(
            [0, self.base_layer.output_size],
            dtype=torch.int32,
            device=next(self.base_layer.parameters()).device,
        )

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        # Record the LoRA A/B weight buffers for this layer and mark the
        # layer as having an active adapter (checked in forward()).
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Two-stage LoRA GEMM through the configured backend:
        # first x @ A, then the result @ B. base_output is forwarded to the
        # backend — presumably fused into the B-GEMM output; confirm against
        # the backend's run_lora_b_sgemm contract.
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            x=lora_a_output,
            weights=self.B_buffer,
            output_offset=self.output_offset,
            base_output=base_output,
        )
        return lora_output

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ReplicatedLinear: apply the (possibly
        # quantized) base projection, then add the LoRA delta if an adapter
        # has been attached via set_lora_info().
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )
        if self.set_lora:
            output = self.apply_lora(output, input_)
        # Mirror ReplicatedLinear's return contract: (output, optional bias
        # when bias addition is deferred to the caller).
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        # No slicing needed for replicated layer - all ranks have the same input
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        # No slicing needed for replicated layer - all ranks produce the same output
        return B
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
def __init__(
self,
@@ -586,6 +648,7 @@ def get_lora_layer(
QKVParallelLinear: QKVParallelLinearWithLoRA,
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
ReplicatedLinear: ReplicatedLinearWithLoRA,
RowParallelLinear: RowParallelLinearWithLoRA,
}
for src_layer_type, lora_layer_type in supported_layer_types.items():

View File

@@ -132,9 +132,19 @@ class LoRAAdapter(nn.Module):
target_module.add("v_proj")
if "qkv_proj" in weight_name:
target_module.add("qkv_proj")
# Check for DeepSeek-V2 MLA architecture modules
if "kv_a_proj_with_mqa" in weight_name or "kv_b_proj" in weight_name:
target_module.add("deepseek_v2_mla")
if len(target_module) == 0:
return
# Check if this is DeepSeek-V2 or V3 with MLA architecture
# These models use q_proj + kv_a_proj_with_mqa + kv_b_proj instead of traditional q/k/v
is_deepseek_mla = "deepseek_v2_mla" in target_module or (
hasattr(self.base_hf_config, "model_type")
and self.base_hf_config.model_type in ["deepseek_v2", "deepseek_v3"]
)
for weight_name in weight_names:
# We assume every lora adaptor should contain lora modules for q_proj
if "q_proj" in weight_name:
@@ -143,6 +153,14 @@ class LoRAAdapter(nn.Module):
v_name = weight_name.replace("q_proj", "v_proj")
qkv_name = weight_name.replace("q_proj", "qkv_proj")
# For DeepSeek-V2/V3 MLA architecture, q_proj is standalone (ColumnParallelLinear)
# Do NOT rename or merge - keep q_proj as is
# The MLA architecture uses separate kv_a_proj_with_mqa and kv_b_proj for K/V
if is_deepseek_mla:
# Keep q_proj unchanged
continue
# Traditional architecture: merge q/k/v
# If k_proj doesn't have lora, initialize it to zero
k_proj_weight = (
weights[k_name]

View File

@@ -69,7 +69,10 @@ def get_hidden_dim(
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
if module_name == "qkv_proj":
if module_name == "q_proj":
# For DeepSeek-V2 MLA and similar architectures where q_proj is not merged
return config.hidden_size, head_dim * config.num_attention_heads
elif module_name == "qkv_proj":
return config.hidden_size, head_dim * (
config.num_attention_heads + config.num_key_value_heads * 2
)
@@ -100,9 +103,17 @@ def get_normalized_target_modules(
"""
Mapping a list of target module name to names of the normalized LoRA weights.
Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
For DeepSeek-V2/V3 MLA architecture, q_proj is kept separate (not merged into qkv_proj).
"""
# Check if this is DeepSeek-V2/V3 MLA architecture
target_modules_list = list(target_modules)
is_deepseek_mla = any(
"kv_a_proj_with_mqa" in name or "kv_b_proj" in name
for name in target_modules_list
)
params_mapping = {
"q_proj": "qkv_proj",
"k_proj": "qkv_proj",
"v_proj": "qkv_proj",
"gate_proj": "gate_up_proj",
@@ -115,8 +126,12 @@ def get_normalized_target_modules(
"output": "lm_head",
}
# For non-MLA architectures, q_proj should also be mapped to qkv_proj
if not is_deepseek_mla:
params_mapping["q_proj"] = "qkv_proj"
result = set()
for name in target_modules:
for name in target_modules_list:
base_name = name.split(".")[-1]
normalized_name = params_mapping.get(base_name, base_name)
result.add(normalized_name)

View File

@@ -3376,6 +3376,66 @@ class DeepseekV2ForCausalLM(nn.Module):
def get_input_embeddings(self) -> nn.Embedding:
return self.model.embed_tokens
def get_hidden_dim(self, module_name: str, layer_idx: int):
    """
    Get input and output dimensions for LoRA modules in DeepSeek-V2 MLA architecture.

    DeepSeek-V2 uses MLA (Multi-head Latent Attention) with:
    - q_proj: standard query projection
    - kv_a_proj_with_mqa: compresses KV to latent space
    - kv_b_proj: expands from latent space to K and V
    - o_proj: output projection

    Returns:
        A ``(input_dim, output_dim)`` tuple for the given module.

    Raises:
        NotImplementedError: If ``module_name`` is not a supported LoRA target.
    """
    config = self.config
    # MLA-specific modules
    if module_name == "q_proj" or module_name == "qkv_proj":
        # Q projection (renamed to qkv_proj in LoRA system): hidden_size -> num_heads * qk_head_dim
        # Note: For DeepSeek-V2, this is only Q, not QKV
        qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        return (
            config.hidden_size,
            config.num_attention_heads * qk_head_dim,
        )
    elif module_name == "kv_a_proj_with_mqa":
        # KV compression: hidden_size -> kv_lora_rank + qk_rope_head_dim
        return (
            config.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
        )
    elif module_name == "kv_b_proj":
        # KV expansion: kv_lora_rank -> num_heads * (qk_nope_head_dim + v_head_dim)
        return (
            config.kv_lora_rank,
            config.num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim),
        )
    elif module_name == "o_proj":
        # Output projection: num_heads * v_head_dim -> hidden_size
        return (
            config.num_attention_heads * config.v_head_dim,
            config.hidden_size,
        )
    # MLP modules (MoE or shared experts)
    elif module_name == "gate_up_proj" or module_name == "down_proj":
        # Determine the correct intermediate size based on layer structure
        # Some layers have regular MLP, others have MoE with SharedExperts
        intermediate_size = config.intermediate_size
        # Check if this layer has SharedExperts (which use different intermediate_size)
        if hasattr(self, "model") and layer_idx < len(self.model.layers):
            mlp = self.model.layers[layer_idx].mlp
            # DeepseekV2MoE has shared_experts attribute
            # NOTE(review): hasattr() is also True when the attribute exists
            # but is None; if a MoE layer can have shared_experts=None (i.e.
            # n_shared_experts unset), the multiplication below would fail —
            # confirm against the model's layer construction.
            if hasattr(mlp, "shared_experts"):
                # SharedExperts use moe_intermediate_size * n_shared_experts
                intermediate_size = config.moe_intermediate_size * config.n_shared_experts
        # gate_up_proj fuses gate_proj and up_proj, hence the factor of 2.
        if module_name == "gate_up_proj":
            return config.hidden_size, intermediate_size * 2
        else:  # down_proj
            return intermediate_size, config.hidden_size
    else:
        raise NotImplementedError(f"Module {module_name} not supported for DeepSeek-V2 LoRA")
@torch.no_grad()
def forward(
self,

76
scripts/lora_demo/chat.py Normal file
View File

@@ -0,0 +1,76 @@
#!/usr/bin/env python3
"""
Simple script to chat with SGLang server
"""
import requests
import json
# Server configuration
SERVER_URL = "http://localhost:8173/v1/chat/completions"
MODEL_NAME = "DeepSeekV2-Lite-west"
LORA_PATH = "/mnt/data/lpl/test_adapter/Kllama_deepseekV2_WEST/checkpoint-1321_converted"
def chat(message, use_lora=False):
    """Send a single chat message to the SGLang server and return the reply.

    Args:
        message: The user message text.
        use_lora: Whether to route the request through the LoRA adapter
            registered on the server under the name "lora0".

    Returns:
        The assistant's reply text, or an "Error: ..." string on transport
        failure or a malformed server response (never raises — callers in
        the interactive loop rely on that).
    """
    payload = {
        "model": MODEL_NAME,
        "messages": [
            {"role": "user", "content": message}
        ],
        "temperature": 0.7,
        "max_tokens": 512
    }
    # Add LoRA name if requested
    # Use the lora name defined in --lora-paths (e.g., "lora0")
    if use_lora:
        payload["lora_path"] = "lora0"  # Use the lora name, not the full path
    try:
        response = requests.post(SERVER_URL, json=payload, timeout=60)
        response.raise_for_status()
        result = response.json()
        return result["choices"][0]["message"]["content"]
    except requests.exceptions.RequestException as e:
        return f"Error: {e}"
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # Fix: a non-JSON body (.json() raises ValueError) or an unexpected
        # schema (KeyError/IndexError/TypeError while indexing "choices")
        # previously propagated and crashed the interactive loop.
        return f"Error: malformed server response ({e})"
def interactive_chat():
    """Read-eval-print chat loop against the server.

    Commands: 'quit'/'exit' end the session, 'lora' toggles the LoRA
    adapter, any other non-empty line is sent as a chat message.
    """
    print("=== SGLang Server Chat ===")
    print(f"Server: {SERVER_URL}")
    print(f"Model: {MODEL_NAME}")
    print("\nType 'quit' or 'exit' to stop")
    print("Type 'lora' to toggle LoRA adapter\n")

    use_lora = False
    while True:
        text = input("\nYou: ").strip()
        command = text.lower()

        if command in ('quit', 'exit'):
            print("Goodbye!")
            return
        if command == 'lora':
            use_lora = not use_lora
            print(f"LoRA adapter: {'enabled' if use_lora else 'disabled'}")
            continue
        if not text:
            continue

        print("\nAssistant: ", end="", flush=True)
        print(chat(text, use_lora=use_lora))
if __name__ == "__main__":
    # Simple test: smoke-check the server connection before entering the loop
    print("Testing server connection...")
    response = chat("Hello")
    print(f"Response: {response}\n")
    # Start interactive chat
    interactive_chat()