Add basic support for Qwen3MoE

turboderp
2025-05-01 20:23:33 +02:00
parent b422a85c47
commit 68976a07d7
6 changed files with 189 additions and 36 deletions

View File

@@ -53,6 +53,10 @@
 layer_keys_mixtral_mlp = [["block_sparse_moe.experts.*.w1"],
                           ["block_sparse_moe.experts.*.w2"],
                           ["block_sparse_moe.experts.*.w3"],
                           ["block_sparse_moe.gate"]]
+layer_keys_qwen3moe_mlp = [["mlp.experts.*.gate_proj"],
+                           ["mlp.experts.*.up_proj"],
+                           ["mlp.experts.*.down_proj"],
+                           ["mlp.gate"]]
 layer_keys_dbrx_mlp = [["block_sparse_moe.experts.*.v1", "block_sparse_moe.experts.v1"],
                        ["block_sparse_moe.experts.*.w1", "block_sparse_moe.experts.w1"],
                        ["block_sparse_moe.experts.*.w2", "block_sparse_moe.experts.w2"],
@@ -441,6 +445,26 @@ class ExLlamaV2ArchParams:
             self.lm.supports_tp = True
             self.lm.default_use_qk_norm = True
 
+        # Qwen3MoE
+        if arch_string == "Qwen3MoeForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_qwen3moe_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+            self.lm.supports_tp = True
+            self.lm.default_use_qk_norm = True
+            self.lm.keys.update({
+                "mlp_gate": ".mlp.experts.*.gate_proj",
+                "mlp_up": ".mlp.experts.*.up_proj",
+                "mlp_down": ".mlp.experts.*.down_proj",
+                "mlp_expert_gate": ".mlp.gate"
+            })
+            self.lm.is_moe = True
+
         # Qwen2-VL (2, 2.5)
         if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
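In these key patterns, "*" stands for the expert index, so each Qwen3MoE layer contributes one gate_proj/up_proj/down_proj set per expert plus a single routing gate. A rough Python sketch of how such a pattern might expand for one layer (the prefix and helper are illustrative, not the loader's actual code):

    def expand_expert_keys(layer_idx: int, pattern: str, num_experts: int) -> list[str]:
        # Assumed Llama-style layer prefix; the real key handling lives in the model loader
        prefix = f"model.layers.{layer_idx}"
        return [prefix + pattern.replace("*", str(e)) for e in range(num_experts)]

    # For Qwen3MoE (128 experts), layer 0:
    expand_expert_keys(0, ".mlp.experts.*.gate_proj", 128)
    # -> ["model.layers.0.mlp.experts.0.gate_proj", ..., "model.layers.0.mlp.experts.127.gate_proj"]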

View File

@@ -319,9 +319,12 @@ class ExLlamaV2Config:
             default_intermediate_size,
             opt_subkey = "text_config",
         )
-        self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None)
+        self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts", "num_experts"], None)
         self.num_experts_per_token = read(read_config, int, ["num_experts_per_tok", "ffn_config->moe_top_k"], None)
+        if self.arch.lm.is_moe:
+            self.intermediate_size = read(read_config, int, ["moe_intermediate_size"], self.intermediate_size)
 
         # Logit/embedding/residual scale
         self.logit_scale = read(read_config, float, "logit_scale", 1)
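The read() helper tries each listed config key in order before falling back to the default, so Qwen3MoE's "num_experts" is picked up alongside the Mixtral and DBRX spellings, and for MoE architectures the per-expert "moe_intermediate_size" replaces the dense intermediate size. A minimal sketch of the assumed fallback semantics (the real helper also validates types and optional subkeys):

    def read_fallback(cfg: dict, keys: list[str], default=None):
        # Keys like "ffn_config->moe_num_experts" walk one level into a nested dict
        for key in keys:
            node = cfg
            try:
                for part in key.split("->"):
                    node = node[part]
                return node
            except (KeyError, TypeError):
                continue
        return default

    cfg = {"num_experts": 128, "moe_intermediate_size": 768}
    read_fallback(cfg, ["num_local_experts", "ffn_config->moe_num_experts", "num_experts"])  # -> 128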

View File

@@ -229,7 +229,10 @@ class AdaptiveGPTQ:
         with torch.inference_mode():
 
-            self.hessian /= self.num_batches
+            if self.hessian is None or self.num_batches == 0:
+                self.hessian = torch.eye(self.rows, device = self.device, dtype = torch.float)
+            else:
+                self.hessian /= self.num_batches
 
             diagonal = torch.diag(self.hessian)
 
             # Prepare weights
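Context for this guard: with 128 experts and only a few routed per token, a rarely-selected expert can finish calibration without receiving any batches, leaving its Hessian accumulator empty. Substituting an identity Hessian lets quantization proceed, weighting all input columns equally, instead of dividing None or dividing by zero. The same logic sketched in isolation:

    import torch

    def finalize_hessian(hessian, num_batches: int, rows: int, device):
        # No calibration data reached this module (possible for rarely-routed experts):
        # fall back to an identity Hessian rather than averaging an empty accumulator
        if hessian is None or num_batches == 0:
            return torch.eye(rows, device = device, dtype = torch.float)
        return hessian / num_batches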

View File

@@ -324,9 +324,15 @@ void QMoEMLP::forward_
 // half* lora_temp
 )
 {
-    if (num_experts != 4 && num_experts != 8 && num_experts != 16)
+    if (rows > MAX_Q_GEMM_WEIGHTS)
     {
-        printf(" ## num_experts must be 4, 8 or 16\n");
+        printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS);
+        DBGI(rows);
+    }
+
+    if (num_experts != 4 && num_experts != 8 && num_experts != 16 && num_experts != 128)
+    {
+        printf(" ## num_experts must be 4, 8, 16 or 128\n");
         return;
     }
@@ -354,36 +360,33 @@ void QMoEMLP::forward_
         &beta_,
         temp_logits, num_experts);
 
-    // Compute softmax filter to and normalize top-k outputs
+    // Select activation kernel
 
-    dim3 blockDim, gridDim;
-    blockDim.x = WARPS;
-    blockDim.y = 1;
-    gridDim.x = 1;
-    gridDim.y = DIVIDE(rows, WARPS);
-    if (num_experts == 4)
-        softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
-    else if (num_experts == 8)
-        softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
-    else if (num_experts == 16)
-        softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+    int intermediate_size = w1[0]->width;
+    fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu);
 
     // For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot
     // product accum and kernels launched with only zero-weights will exit prematurely.
 
-    if (rows <= MAX_Q_GEMM_WEIGHTS)
+    if (num_experts == 4 || num_experts == 8 || num_experts == 16)
     {
-        int intermediate_size = w1[0]->width;
-        fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu);
+        dim3 blockDim, gridDim;
+        blockDim.x = WARPSIZE;
+        blockDim.y = 1;
+        gridDim.x = 1;
+        gridDim.y = DIVIDE(rows, WARPSIZE);
+        if (num_experts == 4)
+            softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+        else if (num_experts == 8)
+            softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+        else if (num_experts == 16)
+            softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
 
         for (int i = 0; i < num_experts; i++)
         {
             gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
             gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
 
-            // apply_loras_cuda(cublas_handle, w1_lora[i], loras, w1[i], temp_state, temp_a, lora_temp, rows);
-            // apply_loras_cuda(cublas_handle, w3_lora[i], loras, w3[i], temp_state, temp_b, lora_temp, rows);
-
             blockDim.x = THREADS_X;
             blockDim.y = THREADS_Y;
             gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1);
@@ -391,17 +394,43 @@ void QMoEMLP::forward_
             kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);
 
             gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true);
-
-            // apply_loras_cuda(cublas_handle, w2_lora[i], loras, w2[i], temp_a, x, lora_temp, rows);
         }
     }
 
-    // Gather larger number of rows in separate batches according to which experts they trigger, evaluate each MLP
-    // only on the affected rows and scale by routing weights while adding back directly onto the residual hidden state
+    // For very large number of experts (Qwen3 etc.) copy to CPU, synchronize and only launch top K experts. This is
+    // not optimal but the kernel launch overhead is very severe otherwise. Really needs a graph
 
-    else
+    else if (num_experts == 128)
     {
-        printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS);
-        DBGI(rows);
+        dim3 blockDim, gridDim;
+        blockDim.x = WARPSIZE;
+        blockDim.y = 1;
+        gridDim.x = 1;
+        gridDim.y = DIVIDE(rows, WARPSIZE);
+        softmax128_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+
+        half* h_logits;
+        h_logits = (half*) malloc(128 * sizeof(half));
+        cudaMemcpyAsync(h_logits, temp_logits, 128 * sizeof(half), cudaMemcpyDeviceToHost, stream);
+        cudaStreamSynchronize(stream);
+
+        for (int i = 0; i < num_experts; i++)
+        {
+            uint16_t w = __half_as_ushort(h_logits[i]);
+            if (!w) continue;
+
+            gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
+            gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
+
+            blockDim.x = THREADS_X;
+            blockDim.y = THREADS_Y;
+            gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1);
+            gridDim.y = DIVIDE(rows, THREADS_Y);
+            kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);
+
+            gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true);
+        }
+
+        free(h_logits);
    }
 }
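The two branches implement the same MoE forward pass with different launch strategies: for small expert counts every expert's kernels are launched and rows with zero routing weight exit early, while for 128 experts the normalized routing weights are copied back to the host so that only the top-K experts are launched at all, at the cost of a stream synchronization. A torch sketch of the equivalent math, with assumed weight shapes and SiLU gating (the CUDA path selects the activation via act_gelu):

    import torch
    import torch.nn.functional as F

    def moe_forward_sketch(x, gate, experts, top_k):
        # x: [rows, hidden], gate: [num_experts, hidden]
        # experts[e] = (w1, w2, w3) with w1/w3: [inter, hidden], w2: [hidden, inter]
        probs = torch.softmax(x @ gate.t(), dim = -1)
        topv, topi = probs.topk(top_k, dim = -1)
        weights = torch.zeros_like(probs)
        weights.scatter_(-1, topi, topv / topv.sum(-1, keepdim = True))
        out = torch.zeros_like(x)
        # Run only experts that received a nonzero weight, as the 128-expert path does
        for e in weights.sum(dim = 0).nonzero().flatten().tolist():
            w1, w2, w3 = experts[e]
            h = F.silu(x @ w1.t()) * (x @ w3.t())
            out += weights[:, e:e+1] * (h @ w2.t())
        return out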

View File

@@ -1,5 +1,5 @@
-#define WARPS 32
+#define WARPSIZE 32
 
 __global__ void softmax16_topk_norm_kernel
 (
@@ -8,7 +8,7 @@ __global__ void softmax16_topk_norm_kernel
     const int topk
 )
 {
-    int row = blockIdx.y * WARPS + threadIdx.x;
+    int row = blockIdx.y * WARPSIZE + threadIdx.x;
     if (row >= rows) return;
 
     // Softmax
@@ -122,7 +122,7 @@ __global__ void softmax8_topk_norm_kernel
     const int topk
 )
 {
-    int row = blockIdx.y * WARPS + threadIdx.x;
+    int row = blockIdx.y * WARPSIZE + threadIdx.x;
     if (row >= rows) return;
 
     // Softmax
@@ -206,7 +206,7 @@ __global__ void softmax4_topk_norm_kernel
     const int topk
 )
 {
-    int row = blockIdx.y * WARPS + threadIdx.x;
+    int row = blockIdx.y * WARPSIZE + threadIdx.x;
     if (row >= rows) return;
 
     // Softmax
@@ -268,3 +268,97 @@ __global__ void softmax4_topk_norm_kernel
     logits_int2.y = l23.as_uint32;
     *row_ptr = logits_int2;
 }
+
+__global__ void softmax128_topk_norm_kernel
+(
+    half* __restrict__ x,
+    const int rows,
+    const int topk
+)
+{
+    const int row = blockIdx.y * WARPSIZE + threadIdx.x;
+    if (row >= rows) return;
+
+    register float f[128];
+
+    int4* row_ptr = reinterpret_cast<int4*>(x + row * 128);
+    #pragma unroll
+    for (int v = 0; v < 16; ++v)  // 16 × 8 halfs = 128 halfs
+    {
+        int4 v4 = row_ptr[v];
+        half2_uint32 h0(v4.x), h1(v4.y), h2(v4.z), h3(v4.w);
+        const int base = v * 8;
+        f[base + 0] = __low2float (h0.as_half2);
+        f[base + 1] = __high2float(h0.as_half2);
+        f[base + 2] = __low2float (h1.as_half2);
+        f[base + 3] = __high2float(h1.as_half2);
+        f[base + 4] = __low2float (h2.as_half2);
+        f[base + 5] = __high2float(h2.as_half2);
+        f[base + 6] = __low2float (h3.as_half2);
+        f[base + 7] = __high2float(h3.as_half2);
+    }
+
+    float maxf = -FLT_MAX;
+    #pragma unroll
+    for (int i = 0; i < 128; ++i) maxf = fmaxf(maxf, f[i]);
+
+    float sum = 0.f;
+    #pragma unroll
+    for (int i = 0; i < 128; ++i)
+    {
+        float e = __expf(f[i] - maxf);
+        f[i] = e;
+        sum += e;
+    }
+
+    constexpr float epsilon = 1e-8f;
+    const float isum = 1.f / (sum + 128.0f * epsilon);
+    #pragma unroll
+    for (int i = 0; i < 128; ++i) f[i] = f[i] * isum + epsilon;
+
+    float remaining = 1.0f;
+    for (int drop = 0; drop < 128 - topk; ++drop)
+    {
+        float minv = 1.0f;
+        int mini = -1;
+        #pragma unroll
+        for (int j = 0; j < 128; ++j)
+        {
+            if (f[j] > 0.0f && f[j] < minv)
+            {
+                minv = f[j];
+                mini = j;
+            }
+        }
+        remaining -= f[mini];
+        f[mini] = 0.0f;
+    }
+
+    const float inv_remaining = 1.f / remaining;
+    #pragma unroll
+    for (int i = 0; i < 128; ++i) f[i] *= inv_remaining;
+
+    #pragma unroll
+    for (int v = 0; v < 16; ++v)
+    {
+        const int base = v * 8;
+        half2_uint32 h0, h1, h2, h3;
+        h0.as_half2 = __floats2half2_rn(f[base + 0], f[base + 1]);
+        h1.as_half2 = __floats2half2_rn(f[base + 2], f[base + 3]);
+        h2.as_half2 = __floats2half2_rn(f[base + 4], f[base + 5]);
+        h3.as_half2 = __floats2half2_rn(f[base + 6], f[base + 7]);
+        int4 v4;
+        v4.x = h0.as_uint32;
+        v4.y = h1.as_uint32;
+        v4.z = h2.as_uint32;
+        v4.w = h3.as_uint32;
+        row_ptr[v] = v4;
+    }
+}
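Per row, handled by a single thread, the kernel computes a softmax over all 128 gate logits, adds a small epsilon floor so every entry is strictly positive, then repeatedly zeroes the smallest remaining entry until only topk survive, and finally rescales the survivors to sum to one. A torch reference of that per-row math, assuming the same epsilon:

    import torch

    def softmax_topk_norm_ref(logits: torch.Tensor, topk: int, eps: float = 1e-8):
        n = logits.numel()
        e = torch.exp(logits.float() - logits.max())
        f = e / (e.sum() + n * eps) + eps          # epsilon floor: every entry > 0
        remaining = 1.0
        for _ in range(n - topk):                  # zero the smallest nonzero entry each pass
            masked = torch.where(f > 0, f, torch.full_like(f, float("inf")))
            i = int(masked.argmin())
            remaining -= float(f[i])
            f[i] = 0.0
        return f / remaining                       # surviving top-k weights sum to ~1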

View File

@@ -167,7 +167,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
     def scratch_space(self) -> int:
 
-        assert self.model.config.intermediate_size >= self.model.config.hidden_size
+        # assert self.model.config.intermediate_size >= self.model.config.hidden_size
         return self.temp_state_size() + \
                self.temp_gathered_state_size() + \
               self.temp_a_size() + \
@@ -235,7 +235,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module):
         # TODO: LoRA currently uses the Torch codepath. Needs conditional (early-exit) kernels with output scaling
         # for the LoRA matmuls in order to work with the C++ path
 
-        if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16] or (loras is not None and len(loras) > 0):
+        if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16, 128] or (loras is not None and len(loras) > 0):
             return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
 
         # if loras is None or self.temp_lora_size == 0:
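In effect, the fused C++ path now also accepts 128-expert models, but only for decode-sized inputs of at most four tokens, without LoRA and without intermediate capture; everything else falls back to forward_torch. A hedged restatement of the dispatch rule as a standalone predicate (names are illustrative):

    def use_fused_cpp_path(q_handle, intermediates, batch_size, sequence_length, num_experts, loras) -> bool:
        # Mirrors the condition in ExLlamaV2MoEMLP.forward above
        return (
            q_handle is not None
            and not intermediates
            and batch_size * sequence_length <= 4
            and num_experts in (4, 8, 16, 128)
            and not loras
        )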