mirror of
https://github.com/turboderp-org/exllamav2.git
synced 2026-03-15 00:07:26 +00:00
Add basic support for Qwen3MoE
This commit is contained in:
@@ -53,6 +53,10 @@ layer_keys_mixtral_mlp = [["block_sparse_moe.experts.*.w1"],
|
||||
["block_sparse_moe.experts.*.w2"],
|
||||
["block_sparse_moe.experts.*.w3"],
|
||||
["block_sparse_moe.gate"]]
|
||||
# Expected MLP tensor-key groups for Qwen3MoE checkpoints: the per-expert
# gate/up/down projections (wildcard `*` matches the expert index) plus the
# shared routing gate that produces the expert logits.
layer_keys_qwen3moe_mlp = [
    ["mlp.experts.*.gate_proj"],
    ["mlp.experts.*.up_proj"],
    ["mlp.experts.*.down_proj"],
    ["mlp.gate"],
]
|
||||
layer_keys_dbrx_mlp = [["block_sparse_moe.experts.*.v1", "block_sparse_moe.experts.v1"],
|
||||
["block_sparse_moe.experts.*.w1", "block_sparse_moe.experts.w1"],
|
||||
["block_sparse_moe.experts.*.w2", "block_sparse_moe.experts.w2"],
|
||||
@@ -441,6 +445,26 @@ class ExLlamaV2ArchParams:
|
||||
self.lm.supports_tp = True
|
||||
self.lm.default_use_qk_norm = True
|
||||
|
||||
# Qwen3MoE
|
||||
|
||||
if arch_string == "Qwen3MoeForCausalLM":
|
||||
arch_recognized = True
|
||||
self.lm.layer_keys += \
|
||||
layer_keys_llama_norms + \
|
||||
layer_keys_llama_attn + \
|
||||
layer_keys_qwen3moe_mlp
|
||||
self.lm.expect_keys += \
|
||||
expect_keys_llama
|
||||
self.lm.supports_tp = True
|
||||
self.lm.default_use_qk_norm = True
|
||||
self.lm.keys.update({
|
||||
"mlp_gate": ".mlp.experts.*.gate_proj",
|
||||
"mlp_up": ".mlp.experts.*.up_proj",
|
||||
"mlp_down": ".mlp.experts.*.down_proj",
|
||||
"mlp_expert_gate": ".mlp.gate"
|
||||
})
|
||||
self.lm.is_moe = True
|
||||
|
||||
# Qwen2-VL (2, 2.5)
|
||||
|
||||
if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
|
||||
|
||||
@@ -319,9 +319,12 @@ class ExLlamaV2Config:
|
||||
default_intermediate_size,
|
||||
opt_subkey = "text_config",
|
||||
)
|
||||
self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None)
|
||||
self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts", "num_experts"], None)
|
||||
self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None)
|
||||
|
||||
if self.arch.lm.is_moe:
|
||||
self.intermediate_size = read(read_config, int, ["moe_intermediate_size"], self.intermediate_size)
|
||||
|
||||
# Logit/embedding/residual scale
|
||||
|
||||
self.logit_scale = read(read_config, float, "logit_scale", 1)
|
||||
|
||||
@@ -229,7 +229,10 @@ class AdaptiveGPTQ:
|
||||
|
||||
with torch.inference_mode():
|
||||
|
||||
self.hessian /= self.num_batches
|
||||
if self.hessian is None or self.num_batches == 0:
|
||||
self.hessian = torch.eye(self.rows, device = self.device, dtype = torch.float)
|
||||
else:
|
||||
self.hessian /= self.num_batches
|
||||
diagonal = torch.diag(self.hessian)
|
||||
|
||||
# Prepare weights
|
||||
|
||||
@@ -324,9 +324,15 @@ void QMoEMLP::forward_
|
||||
// half* lora_temp
|
||||
)
|
||||
{
|
||||
if (num_experts != 4 && num_experts != 8 && num_experts != 16)
|
||||
if (rows > MAX_Q_GEMM_WEIGHTS)
|
||||
{
|
||||
printf(" ## num_experts must be 4, 8 or 16\n");
|
||||
printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS);
|
||||
DBGI(rows);
|
||||
}
|
||||
|
||||
if (num_experts != 4 && num_experts != 8 && num_experts != 16 && num_experts != 128)
|
||||
{
|
||||
printf(" ## num_experts must be 4, 8, 16 or 128\n");
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -354,36 +360,33 @@ void QMoEMLP::forward_
|
||||
&beta_,
|
||||
temp_logits, num_experts);
|
||||
|
||||
// Compute softmax filter to and normalize top-k outputs
|
||||
// Select activation kernel
|
||||
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = WARPS;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = 1;
|
||||
gridDim.y = DIVIDE(rows, WARPS);
|
||||
if (num_experts == 4)
|
||||
softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
|
||||
else if (num_experts == 8)
|
||||
softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
|
||||
else if (num_experts == 16)
|
||||
softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
|
||||
int intermediate_size = w1[0]->width;
|
||||
fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu);
|
||||
|
||||
// For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot
|
||||
// product accum and kernels launched with only zero-weights will exit prematurely.
|
||||
|
||||
if (rows <= MAX_Q_GEMM_WEIGHTS)
|
||||
if (num_experts == 4 || num_experts == 8 || num_experts == 16)
|
||||
{
|
||||
int intermediate_size = w1[0]->width;
|
||||
fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu);
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = WARPSIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = 1;
|
||||
gridDim.y = DIVIDE(rows, WARPSIZE);
|
||||
if (num_experts == 4)
|
||||
softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
|
||||
else if (num_experts == 8)
|
||||
softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
|
||||
else if (num_experts == 16)
|
||||
softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
|
||||
|
||||
for (int i = 0; i < num_experts; i++)
|
||||
{
|
||||
gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
|
||||
gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
|
||||
|
||||
// apply_loras_cuda(cublas_handle, w1_lora[i], loras, w1[i], temp_state, temp_a, lora_temp, rows);
|
||||
// apply_loras_cuda(cublas_handle, w3_lora[i], loras, w3[i], temp_state, temp_b, lora_temp, rows);
|
||||
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = THREADS_Y;
|
||||
gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1);
|
||||
@@ -391,17 +394,43 @@ void QMoEMLP::forward_
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);
|
||||
|
||||
gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true);
|
||||
|
||||
// apply_loras_cuda(cublas_handle, w2_lora[i], loras, w2[i], temp_a, x, lora_temp, rows);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Gather larger number of rows in separate batches according to which experts they trigger, evaluate each MLP
|
||||
// only on the affected rows and scale by routing weights while adding back directly onto the residual hidden state
|
||||
// For very large number of experts (Qwen3 etc.) copy to CPU, synchronize and only launch top K experts. This is
|
||||
// not optimal but the kernel launch overhead is very severe otherwise. Really needs a graph
|
||||
|
||||
else
|
||||
else if (num_experts == 128)
|
||||
{
|
||||
printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS);
|
||||
DBGI(rows);
|
||||
dim3 blockDim, gridDim;
|
||||
blockDim.x = WARPSIZE;
|
||||
blockDim.y = 1;
|
||||
gridDim.x = 1;
|
||||
gridDim.y = DIVIDE(rows, WARPSIZE);
|
||||
softmax128_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
|
||||
|
||||
half* h_logits;
|
||||
h_logits = (half*) malloc(128 * sizeof(half));
|
||||
cudaMemcpyAsync(h_logits, temp_logits, 128 * sizeof(half), cudaMemcpyDeviceToHost, stream);
|
||||
cudaStreamSynchronize(stream);
|
||||
|
||||
for (int i = 0; i < num_experts; i++)
|
||||
{
|
||||
uint16_t w = __half_as_ushort(h_logits[i]);
|
||||
if (!w) continue;
|
||||
|
||||
gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
|
||||
gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
|
||||
|
||||
blockDim.x = THREADS_X;
|
||||
blockDim.y = THREADS_Y;
|
||||
gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1);
|
||||
gridDim.y = DIVIDE(rows, THREADS_Y);
|
||||
kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);
|
||||
|
||||
gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true);
|
||||
}
|
||||
|
||||
free(h_logits);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
|
||||
#define WARPS 32
|
||||
#define WARPSIZE 32
|
||||
|
||||
__global__ void softmax16_topk_norm_kernel
|
||||
(
|
||||
@@ -8,7 +8,7 @@ __global__ void softmax16_topk_norm_kernel
|
||||
const int topk
|
||||
)
|
||||
{
|
||||
int row = blockIdx.y * WARPS + threadIdx.x;
|
||||
int row = blockIdx.y * WARPSIZE + threadIdx.x;
|
||||
if (row >= rows) return;
|
||||
|
||||
// Softmax
|
||||
@@ -122,7 +122,7 @@ __global__ void softmax8_topk_norm_kernel
|
||||
const int topk
|
||||
)
|
||||
{
|
||||
int row = blockIdx.y * WARPS + threadIdx.x;
|
||||
int row = blockIdx.y * WARPSIZE + threadIdx.x;
|
||||
if (row >= rows) return;
|
||||
|
||||
// Softmax
|
||||
@@ -206,7 +206,7 @@ __global__ void softmax4_topk_norm_kernel
|
||||
const int topk
|
||||
)
|
||||
{
|
||||
int row = blockIdx.y * WARPS + threadIdx.x;
|
||||
int row = blockIdx.y * WARPSIZE + threadIdx.x;
|
||||
if (row >= rows) return;
|
||||
|
||||
// Softmax
|
||||
@@ -268,3 +268,97 @@ __global__ void softmax4_topk_norm_kernel
|
||||
logits_int2.y = l23.as_uint32;
|
||||
*row_ptr = logits_int2;
|
||||
}
|
||||
|
||||
// In-place softmax + top-k selection + renormalization over rows of exactly
// 128 logits (the num_experts == 128 MoE routing case, e.g. Qwen3MoE).
//
// Thread mapping: one thread handles one full row. The launch site uses
// blockDim.x == WARPSIZE, gridDim.y == ceil(rows / WARPSIZE), so
// row = blockIdx.y * WARPSIZE + threadIdx.x.
//
// On exit, each row holds normalized routing weights: the top-k softmax
// probabilities rescaled to sum to ~1, and exactly 0 for dropped experts
// (downstream expert loops skip rows whose weight bits are zero).
//
// NOTE(review): assumes 0 <= topk <= 128 — a negative topk would make the
// drop loop read f[-1] when no positive entry remains; confirm the caller
// (num_experts_per_token) guarantees this.
__global__ void softmax128_topk_norm_kernel
(
    half* __restrict__ x,   // [rows, 128] half logits in / routing weights out, 16-byte aligned per row
    const int rows,
    const int topk          // number of experts to keep per row
)
{
    const int row = blockIdx.y * WARPSIZE + threadIdx.x;
    if (row >= rows) return;

    // Whole row staged in a per-thread array. NOTE(review): the dynamic
    // indexing in the drop loop below (f[j], f[mini]) likely forces this
    // into local memory rather than registers despite the keyword.
    register float f[128];

    // Vectorized row pointer: each int4 load/store moves 8 packed halfs.
    int4* row_ptr = reinterpret_cast<int4*>(x + row * 128);

    // Load 128 halfs as 16 x int4 and widen to float.
    #pragma unroll
    for (int v = 0; v < 16; ++v) // 16 × 8 halfs = 128 halfs
    {
        int4 v4 = row_ptr[v];

        // half2_uint32: union-style helper reinterpreting a uint32 as half2
        // (constructed from bits, read back via .as_half2 / .as_uint32).
        half2_uint32 h0(v4.x), h1(v4.y), h2(v4.z), h3(v4.w);

        const int base = v * 8;
        f[base + 0] = __low2float (h0.as_half2);
        f[base + 1] = __high2float(h0.as_half2);
        f[base + 2] = __low2float (h1.as_half2);
        f[base + 3] = __high2float(h1.as_half2);
        f[base + 4] = __low2float (h2.as_half2);
        f[base + 5] = __high2float(h2.as_half2);
        f[base + 6] = __low2float (h3.as_half2);
        f[base + 7] = __high2float(h3.as_half2);
    }

    // Numerically stable softmax: subtract the row max before exponentiating.
    float maxf = -FLT_MAX;
    #pragma unroll
    for (int i = 0; i < 128; ++i) maxf = fmaxf(maxf, f[i]);

    float sum = 0.f;
    #pragma unroll
    for (int i = 0; i < 128; ++i)
    {
        float e = __expf(f[i] - maxf);  // fast-math exp; routing weights tolerate the reduced precision
        f[i] = e;
        sum += e;
    }

    // Add epsilon to every probability so each entry is strictly positive:
    // the drop loop below uses "f[j] > 0" to mean "not yet dropped", and a
    // probability that underflowed to exactly 0 would otherwise be skipped.
    constexpr float epsilon = 1e-8f;
    const float isum = 1.f / (sum + 128.0f * epsilon);

    #pragma unroll
    for (int i = 0; i < 128; ++i) f[i] = f[i] * isum + epsilon;

    // Iteratively drop the (128 - topk) smallest probabilities, zeroing each
    // and deducting it from the surviving probability mass.
    float remaining = 1.0f;
    for (int drop = 0; drop < 128 - topk; ++drop)
    {
        float minv = 1.0f;  // valid upper bound: every f[j] < 1 after normalization
        int mini = -1;
        #pragma unroll
        for (int j = 0; j < 128; ++j)
        {
            if (f[j] > 0.0f && f[j] < minv)
            {
                minv = f[j];
                mini = j;
            }
        }
        remaining -= f[mini];
        f[mini] = 0.0f;
    }

    // Rescale the surviving top-k weights so they sum to ~1.
    const float inv_remaining = 1.f / remaining;
    #pragma unroll
    for (int i = 0; i < 128; ++i) f[i] *= inv_remaining;

    // Pack back to half and store the row with the same 16 x int4 layout.
    #pragma unroll
    for (int v = 0; v < 16; ++v)
    {
        const int base = v * 8;

        half2_uint32 h0, h1, h2, h3;
        h0.as_half2 = __floats2half2_rn(f[base + 0], f[base + 1]);
        h1.as_half2 = __floats2half2_rn(f[base + 2], f[base + 3]);
        h2.as_half2 = __floats2half2_rn(f[base + 4], f[base + 5]);
        h3.as_half2 = __floats2half2_rn(f[base + 6], f[base + 7]);

        int4 v4;
        v4.x = h0.as_uint32;
        v4.y = h1.as_uint32;
        v4.z = h2.as_uint32;
        v4.w = h3.as_uint32;

        row_ptr[v] = v4;
    }
}
|
||||
|
||||
@@ -167,7 +167,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
|
||||
def scratch_space(self) -> int:
|
||||
|
||||
assert self.model.config.intermediate_size >= self.model.config.hidden_size
|
||||
# assert self.model.config.intermediate_size >= self.model.config.hidden_size
|
||||
return self.temp_state_size() + \
|
||||
self.temp_gathered_state_size() + \
|
||||
self.temp_a_size() + \
|
||||
@@ -235,7 +235,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
|
||||
# TODO: LoRA currently uses the Torch codepath. Needs conditional (early-exit) kernels with output scaling
|
||||
# for the LoRA matmuls in order to work with the C++ path
|
||||
|
||||
if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16] or (loras is not None and len(loras) > 0):
|
||||
if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16, 128] or (loras is not None and len(loras) > 0):
|
||||
return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
|
||||
|
||||
# if loras is None or self.temp_lora_size == 0:
|
||||
|
||||
Reference in New Issue
Block a user