From 99f792dce0f568351be37183404954f4ec6c62d5 Mon Sep 17 00:00:00 2001 From: turboderp <11859846+turboderp@users.noreply.github.com> Date: Sun, 1 Mar 2026 15:14:02 +0100 Subject: [PATCH] Add custom activation limits --- exllamav3/exllamav3_ext/activation.cu | 24 ++++++++++++++----- exllamav3/exllamav3_ext/activation.cuh | 12 +++++++--- .../exllamav3_ext/activation_kernels.cuh | 19 +++++++++++++++ .../libtorch/blocksparse_mlp.cpp | 4 ++-- .../exllamav3_ext/libtorch/blocksparse_mlp.h | 7 ++++-- .../libtorch/blocksparse_mlp_bc.h | 6 +++-- exllamav3/exllamav3_ext/libtorch/mlp.cpp | 6 ++--- exllamav3/exllamav3_ext/libtorch/mlp.h | 7 ++++-- exllamav3/exllamav3_ext/libtorch/mlp_bc.h | 6 +++-- exllamav3/modules/block_sparse_mlp.py | 11 +++++---- exllamav3/modules/mlp.py | 7 ++++-- 11 files changed, 81 insertions(+), 28 deletions(-) diff --git a/exllamav3/exllamav3_ext/activation.cu b/exllamav3/exllamav3_ext/activation.cu index b518f31..1446076 100644 --- a/exllamav3/exllamav3_ext/activation.cu +++ b/exllamav3/exllamav3_ext/activation.cu @@ -24,6 +24,7 @@ void silu_mul_gr const at::Tensor& x, const at::Tensor& y, at::Tensor& z, + const float act_limit, Graph* graph ) { @@ -52,6 +53,7 @@ void silu_mul_gr (const float*) x.data_ptr(), (const float*) y.data_ptr(), (half*) z.data_ptr(), + act_limit, numel ); @@ -69,6 +71,7 @@ void silu_mul_gr (const half*) x.data_ptr(), (const half*) y.data_ptr(), (half*) z.data_ptr(), + act_limit, numel ); @@ -85,10 +88,11 @@ void silu_mul ( const at::Tensor& x, const at::Tensor& y, - at::Tensor& z + at::Tensor& z, + const float act_limit ) { - silu_mul_gr(x, y, z, nullptr); + silu_mul_gr(x, y, z, act_limit, nullptr); } // silu(x) * y -> z, in-place if z == x or z == y @@ -98,6 +102,7 @@ void gelu_mul_gr const at::Tensor& x, const at::Tensor& y, at::Tensor& z, + const float act_limit, Graph* graph ) { @@ -126,6 +131,7 @@ void gelu_mul_gr (const float*) x.data_ptr(), (const float*) y.data_ptr(), (half*) z.data_ptr(), + act_limit, numel ); @@ -143,6 +149,7 @@ void gelu_mul_gr (const half*) x.data_ptr(), (const half*) y.data_ptr(), (half*) z.data_ptr(), + act_limit, numel ); @@ -159,10 +166,11 @@ void gelu_mul ( const at::Tensor& x, const at::Tensor& y, - at::Tensor& z + at::Tensor& z, + const float act_limit ) { - gelu_mul_gr(x, y, z, nullptr); + gelu_mul_gr(x, y, z, act_limit, nullptr); } // relu^2(x) * y -> z @@ -172,6 +180,7 @@ void relu2_mul_gr const at::Tensor& x, const at::Tensor& y, at::Tensor& z, + const float act_limit, Graph* graph ) { @@ -200,6 +209,7 @@ void relu2_mul_gr (const float*) x.data_ptr(), (const float*) y.data_ptr(), (half*) z.data_ptr(), + act_limit, numel ); @@ -217,6 +227,7 @@ void relu2_mul_gr (const half*) x.data_ptr(), (const half*) y.data_ptr(), (half*) z.data_ptr(), + act_limit, numel ); @@ -233,10 +244,11 @@ void relu2_mul ( const at::Tensor& x, const at::Tensor& y, - at::Tensor& z + at::Tensor& z, + const float act_limit ) { - relu2_mul_gr(x, y, z, nullptr); + relu2_mul_gr(x, y, z, act_limit, nullptr); } // xielu(x, alpha_p, alpha_n) -> z diff --git a/exllamav3/exllamav3_ext/activation.cuh b/exllamav3/exllamav3_ext/activation.cuh index 481da2c..67c44f1 100644 --- a/exllamav3/exllamav3_ext/activation.cuh +++ b/exllamav3/exllamav3_ext/activation.cuh @@ -8,6 +8,7 @@ void silu_mul_gr const at::Tensor& x, const at::Tensor& y, at::Tensor& z, + const float act_limit, Graph* graph ); @@ -15,7 +16,8 @@ void silu_mul ( const at::Tensor& x, const at::Tensor& y, - at::Tensor& z + at::Tensor& z, + const float act_limit ); void gelu_mul_gr @@ -23,6 +25,7 @@ void gelu_mul_gr const at::Tensor& x, const at::Tensor& y, at::Tensor& z, + const float act_limit, Graph* graph ); @@ -30,7 +33,8 @@ void gelu_mul ( const at::Tensor& x, const at::Tensor& y, - at::Tensor& z + at::Tensor& z, + const float act_limit ); void relu2_mul_gr @@ -38,6 +42,7 @@ void relu2_mul_gr const at::Tensor& x, const at::Tensor& y, at::Tensor& z, + const float act_limit, Graph* graph ); @@ -45,7 +50,8 @@ void relu2_mul ( const at::Tensor& x, const at::Tensor& y, - at::Tensor& z + at::Tensor& z, + const float act_limit ); void xielu_gr diff --git a/exllamav3/exllamav3_ext/activation_kernels.cuh b/exllamav3/exllamav3_ext/activation_kernels.cuh index 07d1ad5..712a524 100644 --- a/exllamav3/exllamav3_ext/activation_kernels.cuh +++ b/exllamav3/exllamav3_ext/activation_kernels.cuh @@ -110,6 +110,7 @@ void act_mul_kernel_h const half* __restrict__ x, const half* __restrict__ y, half* __restrict__ z, + const float act_limit, const size_t numel ) { @@ -126,6 +127,13 @@ void act_mul_kernel_h else if constexpr (activation_type == ACT_RELU2) x2 = _relu2(x2); + if (act_limit != 0.0f) + { + x2 = __hmax2(x2, __float2half2_rn(-act_limit)); + x2 = __hmin2(x2, __float2half2_rn(act_limit)); + y2 = __hmin2(y2, __float2half2_rn(act_limit)); + } + ((half2*) z)[idx] = __hmul2(x2, y2); } @@ -137,6 +145,7 @@ void act_mul_kernel_f const float* __restrict__ x, const float* __restrict__ y, half* __restrict__ z, + const float act_limit, const size_t numel ) { @@ -162,6 +171,16 @@ void act_mul_kernel_f x2.y = _relu2(x2.y); } + if (act_limit != 0.0f) + { + if (x2.x < -act_limit) x2.x = -act_limit; + if (x2.y < -act_limit) x2.y = -act_limit; + if (x2.x > act_limit) x2.x = act_limit; + if (x2.y > act_limit) x2.y = act_limit; + if (y2.x > act_limit) y2.x = act_limit; + if (y2.y > act_limit) y2.y = act_limit; + } + x2.x *= y2.x; x2.y *= y2.y; half2 r = __float22half2_rn(x2); diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp index 1c384c5..d150a97 100644 --- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp +++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp @@ -113,9 +113,9 @@ void BC_BlockSparseMLP::run_bsz1_gr ); if (act_silu) - silu_mul_gr(interm_g, interm_u, interm_a, graph); + silu_mul_gr(interm_g, interm_u, interm_a, act_limit, graph); else if (act_gelu) - gelu_mul_gr(interm_g, interm_u, interm_a, graph); + gelu_mul_gr(interm_g, interm_u, interm_a, act_limit, graph); exl3_mgemm_gr ( diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h index 5980f44..0a02049 100644 --- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h +++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.h @@ -49,6 +49,7 @@ struct BC_BlockSparseMLP bool act_gelu; std::shared_ptr shared_experts; std::shared_ptr shared_gate; + float act_limit; Graph graph_bsz1; @@ -84,7 +85,8 @@ struct BC_BlockSparseMLP bool _act_silu, bool _act_gelu, std::shared_ptr _shared_experts, - std::shared_ptr _shared_gate + std::shared_ptr _shared_gate, + float _act_limit ) : yh (std::move(_yh)), interm_g (std::move(_interm_g)), @@ -116,7 +118,8 @@ struct BC_BlockSparseMLP act_silu (_act_silu), act_gelu (_act_gelu), shared_experts (_shared_experts), - shared_gate (_shared_gate) + shared_gate (_shared_gate), + act_limit (_act_limit) {} void run_bsz1_gr diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h index c69d9cf..3ee9738 100644 --- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h +++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h @@ -31,7 +31,8 @@ py::class_>(m, "BC_BlockSp bool, bool, std::shared_ptr, - std::shared_ptr + std::shared_ptr, + float >(), py::arg("yh"), py::arg("interm_g"), @@ -63,6 +64,7 @@ py::class_>(m, "BC_BlockSp py::arg("act_silu"), py::arg("act_gelu"), py::arg("shared_experts"), - py::arg("shared_gate") + py::arg("shared_gate"), + py::arg("act_limit") ) .def("run_bsz1", &BC_BlockSparseMLP::run_bsz1); diff --git a/exllamav3/exllamav3_ext/libtorch/mlp.cpp b/exllamav3/exllamav3_ext/libtorch/mlp.cpp index c8bf24f..5ff98ae 100644 --- a/exllamav3/exllamav3_ext/libtorch/mlp.cpp +++ b/exllamav3/exllamav3_ext/libtorch/mlp.cpp @@ -41,11 +41,11 @@ void BC_GatedMLP::run_bsz1_gr at::Tensor u = gu.select(0, 1).unsqueeze(0); if (act_silu) - silu_mul_gr(g, u, a, graph); + silu_mul_gr(g, u, a, act_limit, graph); else if (act_gelu) - gelu_mul_gr(g, u, a, graph); + gelu_mul_gr(g, u, a, act_limit, graph); else if (act_relu2) - relu2_mul_gr(g, u, a, graph); + relu2_mul_gr(g, u, a, act_limit, graph); down->run_gr(a, d, graph); } diff --git a/exllamav3/exllamav3_ext/libtorch/mlp.h b/exllamav3/exllamav3_ext/libtorch/mlp.h index 332eab2..97e26af 100644 --- a/exllamav3/exllamav3_ext/libtorch/mlp.h +++ b/exllamav3/exllamav3_ext/libtorch/mlp.h @@ -23,6 +23,7 @@ struct BC_GatedMLP bool act_gelu; bool act_relu2; std::shared_ptr down; + float act_limit; Graph graph_bsz1; @@ -40,7 +41,8 @@ struct BC_GatedMLP bool _act_silu, bool _act_gelu, bool _act_relu2, - std::shared_ptr _down + std::shared_ptr _down, + float _act_limit ) : guh (std::move(_guh)), gu (std::move(_gu)), @@ -54,7 +56,8 @@ struct BC_GatedMLP act_silu (_act_silu), act_gelu (_act_gelu), act_relu2 (_act_relu2), - down (_down) + down (_down), + act_limit (_act_limit) {} void run_bsz1_gr diff --git a/exllamav3/exllamav3_ext/libtorch/mlp_bc.h b/exllamav3/exllamav3_ext/libtorch/mlp_bc.h index 22368f6..1162bc9 100644 --- a/exllamav3/exllamav3_ext/libtorch/mlp_bc.h +++ b/exllamav3/exllamav3_ext/libtorch/mlp_bc.h @@ -13,7 +13,8 @@ py::class_>(m, "BC_GatedMLP").def bool, bool, bool, - std::shared_ptr + std::shared_ptr, + float >(), py::arg("guh"), py::arg("gu"), @@ -27,6 +28,7 @@ py::class_>(m, "BC_GatedMLP").def py::arg("act_silu"), py::arg("act_gelu"), py::arg("act_relu2"), - py::arg("down") + py::arg("down"), + py::arg("act_limit") ) .def("run_bsz1", &BC_GatedMLP::run_bsz1); diff --git a/exllamav3/modules/block_sparse_mlp.py b/exllamav3/modules/block_sparse_mlp.py index 32d8433..58eca6b 100644 --- a/exllamav3/modules/block_sparse_mlp.py +++ b/exllamav3/modules/block_sparse_mlp.py @@ -169,6 +169,7 @@ class BlockSparseMLP(Module): qmap: str | None = None, out_dtype: torch.dtype = None, activation_fn: str = "silu", + act_limit: float = 0.0, interm_dtype: torch.dtype = None, router_type: str = "std", routing_gate: Linear | None = None, @@ -197,6 +198,7 @@ class BlockSparseMLP(Module): self.num_local_experts = num_local_experts if num_local_experts is not None else num_experts self.hidden_size = hidden_size self.router_type = router_type + self.act_limit = act_limit self.routing_first = routing_first self.routing_last = routing_last @@ -458,7 +460,8 @@ class BlockSparseMLP(Module): self.activation_fn == "silu", self.activation_fn == "gelu", sh_exp, - sh_gate + sh_gate, + self.act_limit ) @@ -566,7 +569,7 @@ class BlockSparseMLP(Module): g = self.gates[exp_i].forward(xc, params) u = self.ups[exp_i].forward(xc, params) a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half) - self.activation_fn_call(g, u, a) + self.activation_fn_call(g, u, a, self.act_limit) return self.downs[exp_i].forward(a, params) for expert_idx in range(num_ex): @@ -642,7 +645,7 @@ class BlockSparseMLP(Module): ) # Activation - self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a) + self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a, self.act_limit) # Down ext.exl3_mgemm( @@ -717,7 +720,7 @@ class BlockSparseMLP(Module): ) # Activation - self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a) + self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a, self.act_limit) # Down ext.exl3_mgemm( diff --git a/exllamav3/modules/mlp.py b/exllamav3/modules/mlp.py index bd3549f..fcf83a4 100644 --- a/exllamav3/modules/mlp.py +++ b/exllamav3/modules/mlp.py @@ -374,6 +374,7 @@ class GatedMLP(Module): activation_fn: str = "silu", intermediate_split_size: int | None = MAX_MLP_INTERMEDIATE, interm_dtype: torch.dtype = None, + act_limit: float = 0.0, pad_to = 128, gates: list[Linear | Module] = None, ups: list[Linear | Module] = None, @@ -388,6 +389,7 @@ class GatedMLP(Module): self.hidden_size = hidden_size self.intermediate_split_size = intermediate_split_size self.pad_to = pad_to + self.act_limit = act_limit if key_fused_gate_up: assert not intermediate_split_size or intermediate_size <= intermediate_split_size, \ @@ -573,6 +575,7 @@ class GatedMLP(Module): self.activation_fn == "gelu", self.activation_fn == "relu2", self.downs[0].inner.bc, + self.act_limit, ) @@ -616,7 +619,7 @@ class GatedMLP(Module): g = self.gates[s].forward(x, params) u = self.ups[s].forward(x, params) a = torch.empty_like(u, dtype = torch.half) if self.interm_dtype != torch.half else u - self.activation_fn_call(g, u, a) + self.activation_fn_call(g, u, a, self.act_limit) d_ = self.downs[s].forward(a, params) if d is None: d = d_ @@ -653,7 +656,7 @@ class GatedMLP(Module): u = gu[1].view(bsz, q_len, self.multi_gu[s].out_features) a = torch.empty_like(u, dtype = torch.half) if self.interm_dtype != torch.half else u - self.activation_fn_call(g, u, a) + self.activation_fn_call(g, u, a, self.act_limit) d_ = self.downs[s].forward(a, params) if d is None: d = d_