mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-04-20 14:29:32 +00:00
add lora-sglang with KT
This commit is contained in:
162
convert_lora.py
Normal file
162
convert_lora.py
Normal file
@@ -0,0 +1,162 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Convert PEFT format LoRA adapter to SGLang format.
|
||||
|
||||
This script copies an entire LoRA adapter directory and converts the safetensors
|
||||
weight files from PEFT format to SGLang format by removing unnecessary prefixes.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
|
||||
def convert_lora_adapter(adapter_path: str, output_path: str = None, verbose: bool = True):
    """Convert a PEFT-format LoRA adapter directory to SGLang format.

    Copies every regular file from ``adapter_path`` into the output
    directory.  ``.safetensors`` weight files have their tensor keys
    rewritten (PEFT prefixes stripped); all other files are copied
    verbatim.

    Args:
        adapter_path: Path to the input PEFT adapter directory.
        output_path: Path to the output SGLang adapter directory.
            Defaults to ``{adapter_path}_converted`` when None.
        verbose: Whether to print detailed conversion information.

    Raises:
        FileNotFoundError: If ``adapter_path`` does not exist.
        ValueError: If ``adapter_path`` is not a directory, or if two
            distinct PEFT keys normalize to the same SGLang key (which
            would silently drop a tensor).
    """
    adapter_path = Path(adapter_path)

    # Default output path: sibling directory with a "_converted" suffix.
    if output_path is None:
        output_path = Path(str(adapter_path) + "_converted")
    else:
        output_path = Path(output_path)

    if not adapter_path.exists():
        raise FileNotFoundError(f"Adapter path not found: {adapter_path}")

    if not adapter_path.is_dir():
        raise ValueError(f"Adapter path must be a directory: {adapter_path}")

    # Create output directory (idempotent so re-runs overwrite in place).
    output_path.mkdir(parents=True, exist_ok=True)

    if verbose:
        print("=" * 80)
        print("Converting PEFT LoRA Adapter to SGLang Format")
        print("=" * 80)
        print(f"\nInput: {adapter_path}")
        print(f"Output: {output_path}\n")

    # Process all regular files in the adapter directory.
    converted_files = []
    copied_files = []

    for file_path in adapter_path.iterdir():
        if not file_path.is_file():
            # Skip sub-directories (including a nested output dir, if any).
            continue
        output_file = output_path / file_path.name

        if file_path.suffix == '.safetensors':
            _convert_safetensors_file(file_path, output_file, verbose)
            converted_files.append(file_path.name)
        else:
            # Copy non-weight files (configs, tokenizer, ...) as-is.
            if verbose:
                print(f"Copying: {file_path.name}")
            # copy2 preserves file metadata (mtime, mode) as well.
            shutil.copy2(str(file_path), str(output_file))
            copied_files.append(file_path.name)

    if verbose:
        _print_conversion_summary(converted_files, copied_files, output_path)


def _convert_safetensors_file(file_path: Path, output_file: Path, verbose: bool):
    """Rewrite one safetensors file, stripping PEFT naming prefixes."""
    if verbose:
        print(f"Converting: {file_path.name}")

    # Load PEFT weights.
    state_dict = load_file(str(file_path))

    # Convert keys by removing PEFT prefixes.
    new_state_dict = {}
    for key, value in state_dict.items():
        # PEFT wraps the base model as 'base_model.model.<...>' and may
        # insert '.orig_module' markers; SGLang expects the bare names.
        new_key = key.replace("base_model.model.", "").replace(".orig_module", "")
        if new_key in new_state_dict:
            # Two source keys collapsing to one target would silently
            # drop a tensor -- fail loudly instead.
            raise ValueError(f"Key collision after normalization: {new_key}")
        new_state_dict[new_key] = value

        if verbose and key != new_key:
            print(f"  {key}")
            print(f"  -> {new_key}")

    # Save converted weights.
    save_file(new_state_dict, str(output_file))

    if verbose:
        print(f"  Saved to: {output_file}\n")


def _print_conversion_summary(converted_files, copied_files, output_path):
    """Print a human-readable summary of the conversion run."""
    print("=" * 80)
    print("Conversion Summary")
    print("=" * 80)
    print(f"\nConverted files ({len(converted_files)}):")
    for f in converted_files:
        print(f"  - {f}")

    print(f"\nCopied files ({len(copied_files)}):")
    for f in copied_files:
        print(f"  - {f}")

    print(f"\nTotal: {len(converted_files) + len(copied_files)} files processed")
    print(f"Output directory: {output_path}\n")
|
||||
|
||||
def main():
    """Command-line entry point: parse arguments and run the conversion.

    Exits with status 1 (after printing the error and traceback to
    stderr) if the conversion fails.
    """
    parser = argparse.ArgumentParser(
        description="Convert PEFT format LoRA adapter to SGLang format",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Use default output path (adds _converted suffix)
  python convert_lora.py /path/to/adapter

  # Specify custom output path
  python convert_lora.py /path/to/adapter /path/to/output

  # Quiet mode (minimal output)
  python convert_lora.py /path/to/adapter --quiet
""",
    )

    parser.add_argument(
        "adapter_path",
        type=str,
        help="Path to the input PEFT adapter directory",
    )

    parser.add_argument(
        "output_path",
        type=str,
        nargs="?",
        default=None,
        help="Path to the output SGLang adapter directory (default: {adapter_path}_converted)",
    )

    parser.add_argument(
        "-q", "--quiet",
        action="store_true",
        help="Quiet mode - minimal output",
    )

    args = parser.parse_args()

    try:
        convert_lora_adapter(
            adapter_path=args.adapter_path,
            output_path=args.output_path,
            verbose=not args.quiet,
        )
    except Exception as e:
        # Top-level boundary: report to stderr with the traceback, then
        # exit non-zero.  Use `raise SystemExit` rather than the site
        # helper exit(), which is not guaranteed to exist when the site
        # module is disabled (python -S) or in frozen interpreters.
        import sys
        import traceback

        print(f"\nError: {e}", file=sys.stderr)
        traceback.print_exc()
        raise SystemExit(1)
|
||||
@@ -14,6 +14,7 @@ from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
@@ -362,6 +363,67 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
return B
|
||||
|
||||
|
||||
class ReplicatedLinearWithLoRA(BaseLayerWithLoRA):
    """LoRA support for ReplicatedLinear layers.

    ReplicatedLinear is used in DeepSeek-V2 for layers like kv_a_proj_with_mqa
    where weights are replicated across all TP ranks (not sharded).

    Because the base weight is identical on every rank, neither the LoRA A
    nor the LoRA B matrices need tensor-parallel slicing.
    """

    def __init__(
        self,
        base_layer: ReplicatedLinear,
        lora_backend: BaseLoRABackend,
    ) -> None:
        super().__init__(base_layer, lora_backend)
        # For replicated linear, output is not sharded, so we use a simple offset
        # that spans the full output dimension: [0, output_size).
        self.output_offset = torch.tensor(
            [0, self.base_layer.output_size],
            dtype=torch.int32,
            device=next(self.base_layer.parameters()).device,
        )

    def set_lora_info(
        self,
        A_buffer: torch.Tensor,
        B_buffer: torch.Tensor,
    ):
        # Record the LoRA A/B weight buffers; `set_lora` gates the extra
        # matmuls performed in forward().
        self.set_lora = True
        self.A_buffer = A_buffer
        self.B_buffer = B_buffer

    def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        # Two-step LoRA: first x @ A, then project with B.
        # NOTE(review): assumes run_lora_b_sgemm accumulates the result into
        # base_output (fused add) -- confirm against the backend contract.
        lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
        lora_output = self.lora_backend.run_lora_b_sgemm(
            x=lora_a_output,
            weights=self.B_buffer,
            output_offset=self.output_offset,
            base_output=base_output,
        )
        return lora_output

    def forward(self, input_: torch.Tensor):
        # duplicate the logic in ReplicatedLinear
        bias = self.base_layer.bias if not self.base_layer.skip_bias_add else None
        output = self.base_layer.quant_method.apply(
            self.base_layer, input_, bias
        )

        if self.set_lora:
            output = self.apply_lora(output, input_)

        # When skip_bias_add is set, the bias is returned separately so the
        # caller can fuse the add, mirroring ReplicatedLinear's return contract.
        output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
        return output, output_bias

    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
        # No slicing needed for replicated layer - all ranks have the same input
        return A

    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
        # No slicing needed for replicated layer - all ranks produce the same output
        return B
|
||||
|
||||
|
||||
class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -586,6 +648,7 @@ def get_lora_layer(
|
||||
QKVParallelLinear: QKVParallelLinearWithLoRA,
|
||||
MergedColumnParallelLinear: MergedColumnParallelLinearWithLoRA,
|
||||
ColumnParallelLinear: ColumnParallelLinearWithLoRA,
|
||||
ReplicatedLinear: ReplicatedLinearWithLoRA,
|
||||
RowParallelLinear: RowParallelLinearWithLoRA,
|
||||
}
|
||||
for src_layer_type, lora_layer_type in supported_layer_types.items():
|
||||
|
||||
@@ -132,9 +132,19 @@ class LoRAAdapter(nn.Module):
|
||||
target_module.add("v_proj")
|
||||
if "qkv_proj" in weight_name:
|
||||
target_module.add("qkv_proj")
|
||||
# Check for DeepSeek-V2 MLA architecture modules
|
||||
if "kv_a_proj_with_mqa" in weight_name or "kv_b_proj" in weight_name:
|
||||
target_module.add("deepseek_v2_mla")
|
||||
if len(target_module) == 0:
|
||||
return
|
||||
|
||||
# Check if this is DeepSeek-V2 or V3 with MLA architecture
|
||||
# These models use q_proj + kv_a_proj_with_mqa + kv_b_proj instead of traditional q/k/v
|
||||
is_deepseek_mla = "deepseek_v2_mla" in target_module or (
|
||||
hasattr(self.base_hf_config, "model_type")
|
||||
and self.base_hf_config.model_type in ["deepseek_v2", "deepseek_v3"]
|
||||
)
|
||||
|
||||
for weight_name in weight_names:
|
||||
# We assume every lora adaptor should contain lora modules for q_proj
|
||||
if "q_proj" in weight_name:
|
||||
@@ -143,6 +153,14 @@ class LoRAAdapter(nn.Module):
|
||||
v_name = weight_name.replace("q_proj", "v_proj")
|
||||
qkv_name = weight_name.replace("q_proj", "qkv_proj")
|
||||
|
||||
# For DeepSeek-V2/V3 MLA architecture, q_proj is standalone (ColumnParallelLinear)
|
||||
# Do NOT rename or merge - keep q_proj as is
|
||||
# The MLA architecture uses separate kv_a_proj_with_mqa and kv_b_proj for K/V
|
||||
if is_deepseek_mla:
|
||||
# Keep q_proj unchanged
|
||||
continue
|
||||
|
||||
# Traditional architecture: merge q/k/v
|
||||
# If k_proj doesn't have lora, initialize it to zero
|
||||
k_proj_weight = (
|
||||
weights[k_name]
|
||||
|
||||
@@ -69,7 +69,10 @@ def get_hidden_dim(
|
||||
head_dim = getattr(
|
||||
config, "head_dim", config.hidden_size // config.num_attention_heads
|
||||
)
|
||||
if module_name == "qkv_proj":
|
||||
if module_name == "q_proj":
|
||||
# For DeepSeek-V2 MLA and similar architectures where q_proj is not merged
|
||||
return config.hidden_size, head_dim * config.num_attention_heads
|
||||
elif module_name == "qkv_proj":
|
||||
return config.hidden_size, head_dim * (
|
||||
config.num_attention_heads + config.num_key_value_heads * 2
|
||||
)
|
||||
@@ -100,9 +103,17 @@ def get_normalized_target_modules(
|
||||
"""
|
||||
Mapping a list of target module name to names of the normalized LoRA weights.
|
||||
Handles both base module names (e.g., "gate_proj") and prefixed module names (e.g., "feed_forward.gate_proj").
|
||||
|
||||
For DeepSeek-V2/V3 MLA architecture, q_proj is kept separate (not merged into qkv_proj).
|
||||
"""
|
||||
# Check if this is DeepSeek-V2/V3 MLA architecture
|
||||
target_modules_list = list(target_modules)
|
||||
is_deepseek_mla = any(
|
||||
"kv_a_proj_with_mqa" in name or "kv_b_proj" in name
|
||||
for name in target_modules_list
|
||||
)
|
||||
|
||||
params_mapping = {
|
||||
"q_proj": "qkv_proj",
|
||||
"k_proj": "qkv_proj",
|
||||
"v_proj": "qkv_proj",
|
||||
"gate_proj": "gate_up_proj",
|
||||
@@ -115,8 +126,12 @@ def get_normalized_target_modules(
|
||||
"output": "lm_head",
|
||||
}
|
||||
|
||||
# For non-MLA architectures, q_proj should also be mapped to qkv_proj
|
||||
if not is_deepseek_mla:
|
||||
params_mapping["q_proj"] = "qkv_proj"
|
||||
|
||||
result = set()
|
||||
for name in target_modules:
|
||||
for name in target_modules_list:
|
||||
base_name = name.split(".")[-1]
|
||||
normalized_name = params_mapping.get(base_name, base_name)
|
||||
result.add(normalized_name)
|
||||
|
||||
@@ -3376,6 +3376,66 @@ class DeepseekV2ForCausalLM(nn.Module):
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.model.embed_tokens
|
||||
|
||||
def get_hidden_dim(self, module_name: str, layer_idx: int):
    """
    Get input and output dimensions for LoRA modules in DeepSeek-V2 MLA architecture.

    DeepSeek-V2 uses MLA (Multi-head Latent Attention) with:
    - q_proj: standard query projection
    - kv_a_proj_with_mqa: compresses KV to latent space
    - kv_b_proj: expands from latent space to K and V
    - o_proj: output projection

    Args:
        module_name: Normalized LoRA module name (e.g. "q_proj", "o_proj",
            "gate_up_proj").
        layer_idx: Index of the decoder layer; used to distinguish dense-MLP
            layers from MoE layers that carry shared experts.

    Returns:
        A ``(input_dim, output_dim)`` tuple for the named module.

    Raises:
        NotImplementedError: If ``module_name`` is not a supported module.
    """
    config = self.config

    # MLA-specific modules
    if module_name == "q_proj" or module_name == "qkv_proj":
        # Q projection (renamed to qkv_proj in LoRA system): hidden_size -> num_heads * qk_head_dim
        # Note: For DeepSeek-V2, this is only Q, not QKV
        qk_head_dim = config.qk_nope_head_dim + config.qk_rope_head_dim
        return (
            config.hidden_size,
            config.num_attention_heads * qk_head_dim,
        )
    elif module_name == "kv_a_proj_with_mqa":
        # KV compression: hidden_size -> kv_lora_rank + qk_rope_head_dim
        return (
            config.hidden_size,
            config.kv_lora_rank + config.qk_rope_head_dim,
        )
    elif module_name == "kv_b_proj":
        # KV expansion: kv_lora_rank -> num_heads * (qk_nope_head_dim + v_head_dim)
        return (
            config.kv_lora_rank,
            config.num_attention_heads * (config.qk_nope_head_dim + config.v_head_dim),
        )
    elif module_name == "o_proj":
        # Output projection: num_heads * v_head_dim -> hidden_size
        return (
            config.num_attention_heads * config.v_head_dim,
            config.hidden_size,
        )
    # MLP modules (MoE or shared experts)
    elif module_name == "gate_up_proj" or module_name == "down_proj":
        # Determine the correct intermediate size based on layer structure
        # Some layers have regular MLP, others have MoE with SharedExperts
        intermediate_size = config.intermediate_size

        # Check if this layer has SharedExperts (which use different intermediate_size)
        if hasattr(self, "model") and layer_idx < len(self.model.layers):
            mlp = self.model.layers[layer_idx].mlp
            # DeepseekV2MoE has shared_experts attribute
            if hasattr(mlp, "shared_experts"):
                # SharedExperts use moe_intermediate_size * n_shared_experts
                # NOTE(review): assumes LoRA on MoE layers targets the shared
                # experts' projections (not the routed experts) -- confirm.
                intermediate_size = config.moe_intermediate_size * config.n_shared_experts

        if module_name == "gate_up_proj":
            # gate and up projections are fused, hence the factor of 2.
            return config.hidden_size, intermediate_size * 2
        else:  # down_proj
            return intermediate_size, config.hidden_size
    else:
        raise NotImplementedError(f"Module {module_name} not supported for DeepSeek-V2 LoRA")
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
76
scripts/lora_demo/chat.py
Normal file
76
scripts/lora_demo/chat.py
Normal file
@@ -0,0 +1,76 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple script to chat with SGLang server
|
||||
"""
|
||||
import requests
|
||||
import json
|
||||
|
||||
# Server configuration
|
||||
SERVER_URL = "http://localhost:8173/v1/chat/completions"
|
||||
MODEL_NAME = "DeepSeekV2-Lite-west"
|
||||
LORA_PATH = "/mnt/data/lpl/test_adapter/Kllama_deepseekV2_WEST/checkpoint-1321_converted"
|
||||
|
||||
def chat(message, use_lora=False):
    """Send a chat message to the server"""
    payload = {
        "model": MODEL_NAME,
        "messages": [{"role": "user", "content": message}],
        "temperature": 0.7,
        "max_tokens": 512,
    }

    # Add LoRA name if requested
    # Use the lora name defined in --lora-paths (e.g., "lora0")
    if use_lora:
        payload["lora_path"] = "lora0"  # Use the lora name, not the full path

    try:
        resp = requests.post(SERVER_URL, json=payload, timeout=60)
        resp.raise_for_status()
        body = resp.json()
    except requests.exceptions.RequestException as e:
        # Network / HTTP errors are reported inline as a chat response.
        return f"Error: {e}"

    return body["choices"][0]["message"]["content"]
||||
|
||||
def interactive_chat():
    """Interactive chat loop"""
    print("=== SGLang Server Chat ===")
    print(f"Server: {SERVER_URL}")
    print(f"Model: {MODEL_NAME}")
    print("\nType 'quit' or 'exit' to stop")
    print("Type 'lora' to toggle LoRA adapter\n")

    use_lora = False

    while True:
        user_input = input("\nYou: ").strip()
        command = user_input.lower()

        # Exit commands end the loop.
        if command in ('quit', 'exit'):
            print("Goodbye!")
            return

        # 'lora' toggles whether requests carry the adapter.
        if command == 'lora':
            state = 'enabled' if use_lora else 'disabled'
            use_lora = not use_lora
            state = 'enabled' if use_lora else 'disabled'
            print(f"LoRA adapter: {state}")
        elif user_input:
            # Non-empty message: send it and print the reply.
            print("\nAssistant: ", end="", flush=True)
            print(chat(user_input, use_lora=use_lora))
||||
|
||||
if __name__ == "__main__":
    # Simple test
    # Smoke-test the server with one request before entering the REPL;
    # connection errors surface here as an "Error: ..." response string.
    print("Testing server connection...")
    response = chat("Hello")
    print(f"Response: {response}\n")

    # Start interactive chat
    interactive_chat()
|
||||
Reference in New Issue
Block a user