diff --git a/exllamav2/architecture.py b/exllamav2/architecture.py index 300fe0c..f318dd4 100644 --- a/exllamav2/architecture.py +++ b/exllamav2/architecture.py @@ -53,6 +53,10 @@ layer_keys_mixtral_mlp = [["block_sparse_moe.experts.*.w1"], ["block_sparse_moe.experts.*.w2"], ["block_sparse_moe.experts.*.w3"], ["block_sparse_moe.gate"]] +layer_keys_qwen3moe_mlp = [["mlp.experts.*.gate_proj"], + ["mlp.experts.*.up_proj"], + ["mlp.experts.*.down_proj"], + ["mlp.gate"]] layer_keys_dbrx_mlp = [["block_sparse_moe.experts.*.v1", "block_sparse_moe.experts.v1"], ["block_sparse_moe.experts.*.w1", "block_sparse_moe.experts.w1"], ["block_sparse_moe.experts.*.w2", "block_sparse_moe.experts.w2"], @@ -428,6 +432,39 @@ class ExLlamaV2ArchParams: self.lm.attention_bias_qkv = True self.lm.supports_tp = True + # Qwen3 + + if arch_string == "Qwen3ForCausalLM": + arch_recognized = True + self.lm.layer_keys += \ + layer_keys_llama_norms + \ + layer_keys_llama_attn + \ + layer_keys_llama_mlp + self.lm.expect_keys += \ + expect_keys_llama + self.lm.supports_tp = True + self.lm.default_use_qk_norm = True + + # Qwen3MoE + + if arch_string == "Qwen3MoeForCausalLM": + arch_recognized = True + self.lm.layer_keys += \ + layer_keys_llama_norms + \ + layer_keys_llama_attn + \ + layer_keys_qwen3moe_mlp + self.lm.expect_keys += \ + expect_keys_llama + self.lm.supports_tp = True + self.lm.default_use_qk_norm = True + self.lm.keys.update({ + "mlp_gate": ".mlp.experts.*.gate_proj", + "mlp_up": ".mlp.experts.*.up_proj", + "mlp_down": ".mlp.experts.*.down_proj", + "mlp_expert_gate": ".mlp.gate" + }) + self.lm.is_moe = True + # Qwen2-VL (2, 2.5) if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]: diff --git a/exllamav2/config.py b/exllamav2/config.py index ec5abb8..260f228 100644 --- a/exllamav2/config.py +++ b/exllamav2/config.py @@ -319,9 +319,12 @@ class ExLlamaV2Config: default_intermediate_size, opt_subkey = "text_config", ) - self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None) + self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts", "num_experts"], None) self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None) + if self.arch.lm.is_moe: + self.intermediate_size = read(read_config, int, ["moe_intermediate_size"], self.intermediate_size) + # Logit/embedding/residual scale self.logit_scale = read(read_config, float, "logit_scale", 1) diff --git a/exllamav2/conversion/adaptivegptq.py b/exllamav2/conversion/adaptivegptq.py index 22382ba..1c8e84a 100644 --- a/exllamav2/conversion/adaptivegptq.py +++ b/exllamav2/conversion/adaptivegptq.py @@ -229,7 +229,10 @@ class AdaptiveGPTQ: with torch.inference_mode(): - self.hessian /= self.num_batches + if self.hessian is None or self.num_batches == 0: + self.hessian = torch.eye(self.rows, device = self.device, dtype = torch.float) + else: + self.hessian /= self.num_batches diagonal = torch.diag(self.hessian) # Prepare weights diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp.cu b/exllamav2/exllamav2_ext/cuda/q_mlp.cu index 8e9cda2..d881fc9 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp.cu +++ b/exllamav2/exllamav2_ext/cuda/q_mlp.cu @@ -324,9 +324,15 @@ void QMoEMLP::forward_ // half* lora_temp ) { - if (num_experts != 4 && num_experts != 8 && num_experts != 16) + if (rows > MAX_Q_GEMM_WEIGHTS) { - printf(" ## num_experts must be 4, 8 or 16\n"); + printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS); + DBGI(rows); + } + + if (num_experts != 4 && num_experts != 8 && num_experts != 16 && num_experts != 128) + { + printf(" ## num_experts must be 4, 8, 16 or 128\n"); return; } @@ -354,36 +360,33 @@ void QMoEMLP::forward_ &beta_, temp_logits, num_experts); - // Compute softmax filter to and normalize top-k outputs + // Select activation kernel - dim3 blockDim, gridDim; - blockDim.x = WARPS; - blockDim.y = 1; - gridDim.x = 1; - gridDim.y = DIVIDE(rows, WARPS); - if (num_experts == 4) - softmax4_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); - else if (num_experts == 8) - softmax8_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); - else if (num_experts == 16) - softmax16_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); + int intermediate_size = w1[0]->width; + fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu); // For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot // product accum and kernels launched with only zero-weights will exit prematurely. - if (rows <= MAX_Q_GEMM_WEIGHTS) + if (num_experts == 4 || num_experts == 8 || num_experts == 16) { - int intermediate_size = w1[0]->width; - fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu); + dim3 blockDim, gridDim; + blockDim.x = WARPSIZE; + blockDim.y = 1; + gridDim.x = 1; + gridDim.y = DIVIDE(rows, WARPSIZE); + if (num_experts == 4) + softmax4_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); + else if (num_experts == 8) + softmax8_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); + else if (num_experts == 16) + softmax16_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); for (int i = 0; i < num_experts; i++) { gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); -// apply_loras_cuda(cublas_handle, w1_lora[i], loras, w1[i], temp_state, temp_a, lora_temp, rows); -// apply_loras_cuda(cublas_handle, w3_lora[i], loras, w3[i], temp_state, temp_b, lora_temp, rows); - blockDim.x = THREADS_X; blockDim.y = THREADS_Y; gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1); @@ -391,17 +394,43 @@ void QMoEMLP::forward_ kernel<<>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts); gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true); - -// apply_loras_cuda(cublas_handle, w2_lora[i], loras, w2[i], temp_a, x, lora_temp, rows); } - } + } - // Gather larger number of rows in separate batches according to which experts they trigger, evaluate each MLP - // only on the affected rows and scale by routing weights while adding back directly onto the residual hidden state + // For very large number of experts (Qwen3 etc.) copy to CPU, synchronize and only launch top K experts. This is + // not optimal but the kernel launch overhead is very severe otherwise. Really needs a graph - else + else if (num_experts == 128) { - printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS); - DBGI(rows); + dim3 blockDim, gridDim; + blockDim.x = WARPSIZE; + blockDim.y = 1; + gridDim.x = 1; + gridDim.y = DIVIDE(rows, WARPSIZE); + softmax128_topk_norm_kernel<<>>(temp_logits, rows, num_experts_per_token); + + half* h_logits; + h_logits = (half*) malloc(128 * sizeof(half)); + cudaMemcpyAsync(h_logits, temp_logits, 128 * sizeof(half), cudaMemcpyDeviceToHost, stream); + cudaStreamSynchronize(stream); + + for (int i = 0; i < num_experts; i++) + { + uint16_t w = __half_as_ushort(h_logits[i]); + if (!w) continue; + + gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); + gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false); + + blockDim.x = THREADS_X; + blockDim.y = THREADS_Y; + gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1); + gridDim.y = DIVIDE(rows, THREADS_Y); + kernel<<>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts); + + gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true); + } + + free(h_logits); } } diff --git a/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh b/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh index 2cf1b59..c97b66b 100644 --- a/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh +++ b/exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh @@ -1,5 +1,5 @@ -#define WARPS 32 +#define WARPSIZE 32 __global__ void softmax16_topk_norm_kernel ( @@ -8,7 +8,7 @@ __global__ void softmax16_topk_norm_kernel const int topk ) { - int row = blockIdx.y * WARPS + threadIdx.x; + int row = blockIdx.y * WARPSIZE + threadIdx.x; if (row >= rows) return; // Softmax @@ -122,7 +122,7 @@ __global__ void softmax8_topk_norm_kernel const int topk ) { - int row = blockIdx.y * WARPS + threadIdx.x; + int row = blockIdx.y * WARPSIZE + threadIdx.x; if (row >= rows) return; // Softmax @@ -206,7 +206,7 @@ __global__ void softmax4_topk_norm_kernel const int topk ) { - int row = blockIdx.y * WARPS + threadIdx.x; + int row = blockIdx.y * WARPSIZE + threadIdx.x; if (row >= rows) return; // Softmax @@ -268,3 +268,97 @@ __global__ void softmax4_topk_norm_kernel logits_int2.y = l23.as_uint32; *row_ptr = logits_int2; } + +__global__ void softmax128_topk_norm_kernel +( + half* __restrict__ x, + const int rows, + const int topk +) +{ + const int row = blockIdx.y * WARPSIZE + threadIdx.x; + if (row >= rows) return; + + register float f[128]; + + int4* row_ptr = reinterpret_cast(x + row * 128); + + #pragma unroll + for (int v = 0; v < 16; ++v) // 16 × 8 halfs = 128 halfs + { + int4 v4 = row_ptr[v]; + + half2_uint32 h0(v4.x), h1(v4.y), h2(v4.z), h3(v4.w); + + const int base = v * 8; + f[base + 0] = __low2float (h0.as_half2); + f[base + 1] = __high2float(h0.as_half2); + f[base + 2] = __low2float (h1.as_half2); + f[base + 3] = __high2float(h1.as_half2); + f[base + 4] = __low2float (h2.as_half2); + f[base + 5] = __high2float(h2.as_half2); + f[base + 6] = __low2float (h3.as_half2); + f[base + 7] = __high2float(h3.as_half2); + } + + float maxf = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 128; ++i) maxf = fmaxf(maxf, f[i]); + + float sum = 0.f; + #pragma unroll + for (int i = 0; i < 128; ++i) + { + float e = __expf(f[i] - maxf); + f[i] = e; + sum += e; + } + + constexpr float epsilon = 1e-8f; + const float isum = 1.f / (sum + 128.0f * epsilon); + + #pragma unroll + for (int i = 0; i < 128; ++i) f[i] = f[i] * isum + epsilon; + + float remaining = 1.0f; + for (int drop = 0; drop < 128 - topk; ++drop) + { + float minv = 1.0f; + int mini = -1; + #pragma unroll + for (int j = 0; j < 128; ++j) + { + if (f[j] > 0.0f && f[j] < minv) + { + minv = f[j]; + mini = j; + } + } + remaining -= f[mini]; + f[mini] = 0.0f; + } + + const float inv_remaining = 1.f / remaining; + #pragma unroll + for (int i = 0; i < 128; ++i) f[i] *= inv_remaining; + + #pragma unroll + for (int v = 0; v < 16; ++v) + { + const int base = v * 8; + + half2_uint32 h0, h1, h2, h3; + h0.as_half2 = __floats2half2_rn(f[base + 0], f[base + 1]); + h1.as_half2 = __floats2half2_rn(f[base + 2], f[base + 3]); + h2.as_half2 = __floats2half2_rn(f[base + 4], f[base + 5]); + h3.as_half2 = __floats2half2_rn(f[base + 6], f[base + 7]); + + int4 v4; + v4.x = h0.as_uint32; + v4.y = h1.as_uint32; + v4.z = h2.as_uint32; + v4.w = h3.as_uint32; + + row_ptr[v] = v4; + } +} diff --git a/exllamav2/moe_mlp.py b/exllamav2/moe_mlp.py index 02ebbfd..03559d9 100644 --- a/exllamav2/moe_mlp.py +++ b/exllamav2/moe_mlp.py @@ -167,7 +167,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module): def scratch_space(self) -> int: - assert self.model.config.intermediate_size >= self.model.config.hidden_size + # assert self.model.config.intermediate_size >= self.model.config.hidden_size return self.temp_state_size() + \ self.temp_gathered_state_size() + \ self.temp_a_size() + \ @@ -235,7 +235,7 @@ class ExLlamaV2MoEMLP(ExLlamaV2Module): # TODO: LoRA currently uses the Torch codepath. Needs conditional (early-exit) kernels with output scaling # for the LoRA matmuls in order to work with the C++ path - if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16] or (loras is not None and len(loras) > 0): + if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16, 128] or (loras is not None and len(loras) > 0): return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs) # if loras is None or self.temp_lora_size == 0: diff --git a/exllamav2/vlm/vision_tower.py b/exllamav2/vlm/vision_tower.py index 07f40dd..8781e3b 100644 --- a/exllamav2/vlm/vision_tower.py +++ b/exllamav2/vlm/vision_tower.py @@ -42,6 +42,8 @@ class ExLlamaV2VisionTower(ExLlamaV2): km = self.archparams.keys self.modules = [] + self.tp_context = None + # Preprocessor if cfg.vision_model_type == "pixtral":