mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
Add custom activation limits
This commit is contained in:
@@ -24,6 +24,7 @@ void silu_mul_gr
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
const float act_limit,
|
||||
Graph* graph
|
||||
)
|
||||
{
|
||||
@@ -52,6 +53,7 @@ void silu_mul_gr
|
||||
(const float*) x.data_ptr(),
|
||||
(const float*) y.data_ptr(),
|
||||
(half*) z.data_ptr(),
|
||||
act_limit,
|
||||
numel
|
||||
);
|
||||
|
||||
@@ -69,6 +71,7 @@ void silu_mul_gr
|
||||
(const half*) x.data_ptr(),
|
||||
(const half*) y.data_ptr(),
|
||||
(half*) z.data_ptr(),
|
||||
act_limit,
|
||||
numel
|
||||
);
|
||||
|
||||
@@ -85,10 +88,11 @@ void silu_mul
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
at::Tensor& z,
|
||||
const float act_limit
|
||||
)
|
||||
{
|
||||
silu_mul_gr(x, y, z, nullptr);
|
||||
silu_mul_gr(x, y, z, act_limit, nullptr);
|
||||
}
|
||||
|
||||
// gelu(x) * y -> z, in-place if z == x or z == y
|
||||
@@ -98,6 +102,7 @@ void gelu_mul_gr
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
const float act_limit,
|
||||
Graph* graph
|
||||
)
|
||||
{
|
||||
@@ -126,6 +131,7 @@ void gelu_mul_gr
|
||||
(const float*) x.data_ptr(),
|
||||
(const float*) y.data_ptr(),
|
||||
(half*) z.data_ptr(),
|
||||
act_limit,
|
||||
numel
|
||||
);
|
||||
|
||||
@@ -143,6 +149,7 @@ void gelu_mul_gr
|
||||
(const half*) x.data_ptr(),
|
||||
(const half*) y.data_ptr(),
|
||||
(half*) z.data_ptr(),
|
||||
act_limit,
|
||||
numel
|
||||
);
|
||||
|
||||
@@ -159,10 +166,11 @@ void gelu_mul
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
at::Tensor& z,
|
||||
const float act_limit
|
||||
)
|
||||
{
|
||||
gelu_mul_gr(x, y, z, nullptr);
|
||||
gelu_mul_gr(x, y, z, act_limit, nullptr);
|
||||
}
|
||||
|
||||
// relu^2(x) * y -> z
|
||||
@@ -172,6 +180,7 @@ void relu2_mul_gr
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
const float act_limit,
|
||||
Graph* graph
|
||||
)
|
||||
{
|
||||
@@ -200,6 +209,7 @@ void relu2_mul_gr
|
||||
(const float*) x.data_ptr(),
|
||||
(const float*) y.data_ptr(),
|
||||
(half*) z.data_ptr(),
|
||||
act_limit,
|
||||
numel
|
||||
);
|
||||
|
||||
@@ -217,6 +227,7 @@ void relu2_mul_gr
|
||||
(const half*) x.data_ptr(),
|
||||
(const half*) y.data_ptr(),
|
||||
(half*) z.data_ptr(),
|
||||
act_limit,
|
||||
numel
|
||||
);
|
||||
|
||||
@@ -233,10 +244,11 @@ void relu2_mul
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
at::Tensor& z,
|
||||
const float act_limit
|
||||
)
|
||||
{
|
||||
relu2_mul_gr(x, y, z, nullptr);
|
||||
relu2_mul_gr(x, y, z, act_limit, nullptr);
|
||||
}
|
||||
|
||||
// xielu(x, alpha_p, alpha_n) -> z
|
||||
|
||||
@@ -8,6 +8,7 @@ void silu_mul_gr
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
const float act_limit,
|
||||
Graph* graph
|
||||
);
|
||||
|
||||
@@ -15,7 +16,8 @@ void silu_mul
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
at::Tensor& z,
|
||||
const float act_limit
|
||||
);
|
||||
|
||||
void gelu_mul_gr
|
||||
@@ -23,6 +25,7 @@ void gelu_mul_gr
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
const float act_limit,
|
||||
Graph* graph
|
||||
);
|
||||
|
||||
@@ -30,7 +33,8 @@ void gelu_mul
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
at::Tensor& z,
|
||||
const float act_limit
|
||||
);
|
||||
|
||||
void relu2_mul_gr
|
||||
@@ -38,6 +42,7 @@ void relu2_mul_gr
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
const float act_limit,
|
||||
Graph* graph
|
||||
);
|
||||
|
||||
@@ -45,7 +50,8 @@ void relu2_mul
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
at::Tensor& z,
|
||||
const float act_limit
|
||||
);
|
||||
|
||||
void xielu_gr
|
||||
|
||||
@@ -110,6 +110,7 @@ void act_mul_kernel_h
|
||||
const half* __restrict__ x,
|
||||
const half* __restrict__ y,
|
||||
half* __restrict__ z,
|
||||
const float act_limit,
|
||||
const size_t numel
|
||||
)
|
||||
{
|
||||
@@ -126,6 +127,13 @@ void act_mul_kernel_h
|
||||
else if constexpr (activation_type == ACT_RELU2)
|
||||
x2 = _relu2(x2);
|
||||
|
||||
if (act_limit != 0.0f)
|
||||
{
|
||||
x2 = __hmax2(x2, __float2half2_rn(-act_limit));
|
||||
x2 = __hmin2(x2, __float2half2_rn(act_limit));
|
||||
y2 = __hmin2(y2, __float2half2_rn(act_limit));
|
||||
}
|
||||
|
||||
((half2*) z)[idx] = __hmul2(x2, y2);
|
||||
}
|
||||
|
||||
@@ -137,6 +145,7 @@ void act_mul_kernel_f
|
||||
const float* __restrict__ x,
|
||||
const float* __restrict__ y,
|
||||
half* __restrict__ z,
|
||||
const float act_limit,
|
||||
const size_t numel
|
||||
)
|
||||
{
|
||||
@@ -162,6 +171,16 @@ void act_mul_kernel_f
|
||||
x2.y = _relu2(x2.y);
|
||||
}
|
||||
|
||||
if (act_limit != 0.0f)
|
||||
{
|
||||
if (x2.x < -act_limit) x2.x = -act_limit;
|
||||
if (x2.y < -act_limit) x2.y = -act_limit;
|
||||
if (x2.x > act_limit) x2.x = act_limit;
|
||||
if (x2.y > act_limit) x2.y = act_limit;
|
||||
if (y2.x > act_limit) y2.x = act_limit;
|
||||
if (y2.y > act_limit) y2.y = act_limit;
|
||||
}
|
||||
|
||||
x2.x *= y2.x;
|
||||
x2.y *= y2.y;
|
||||
half2 r = __float22half2_rn(x2);
|
||||
|
||||
@@ -113,9 +113,9 @@ void BC_BlockSparseMLP::run_bsz1_gr
|
||||
);
|
||||
|
||||
if (act_silu)
|
||||
silu_mul_gr(interm_g, interm_u, interm_a, graph);
|
||||
silu_mul_gr(interm_g, interm_u, interm_a, act_limit, graph);
|
||||
else if (act_gelu)
|
||||
gelu_mul_gr(interm_g, interm_u, interm_a, graph);
|
||||
gelu_mul_gr(interm_g, interm_u, interm_a, act_limit, graph);
|
||||
|
||||
exl3_mgemm_gr
|
||||
(
|
||||
|
||||
@@ -49,6 +49,7 @@ struct BC_BlockSparseMLP
|
||||
bool act_gelu;
|
||||
std::shared_ptr<BC_GatedMLP> shared_experts;
|
||||
std::shared_ptr<BC_LinearFP16> shared_gate;
|
||||
float act_limit;
|
||||
|
||||
Graph graph_bsz1;
|
||||
|
||||
@@ -84,7 +85,8 @@ struct BC_BlockSparseMLP
|
||||
bool _act_silu,
|
||||
bool _act_gelu,
|
||||
std::shared_ptr<BC_GatedMLP> _shared_experts,
|
||||
std::shared_ptr<BC_LinearFP16> _shared_gate
|
||||
std::shared_ptr<BC_LinearFP16> _shared_gate,
|
||||
float _act_limit
|
||||
) :
|
||||
yh (std::move(_yh)),
|
||||
interm_g (std::move(_interm_g)),
|
||||
@@ -116,7 +118,8 @@ struct BC_BlockSparseMLP
|
||||
act_silu (_act_silu),
|
||||
act_gelu (_act_gelu),
|
||||
shared_experts (_shared_experts),
|
||||
shared_gate (_shared_gate)
|
||||
shared_gate (_shared_gate),
|
||||
act_limit (_act_limit)
|
||||
{}
|
||||
|
||||
void run_bsz1_gr
|
||||
|
||||
@@ -31,7 +31,8 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
|
||||
bool,
|
||||
bool,
|
||||
std::shared_ptr<BC_GatedMLP>,
|
||||
std::shared_ptr<BC_LinearFP16>
|
||||
std::shared_ptr<BC_LinearFP16>,
|
||||
float
|
||||
>(),
|
||||
py::arg("yh"),
|
||||
py::arg("interm_g"),
|
||||
@@ -63,6 +64,7 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
|
||||
py::arg("act_silu"),
|
||||
py::arg("act_gelu"),
|
||||
py::arg("shared_experts"),
|
||||
py::arg("shared_gate")
|
||||
py::arg("shared_gate"),
|
||||
py::arg("act_limit")
|
||||
)
|
||||
.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1);
|
||||
|
||||
@@ -41,11 +41,11 @@ void BC_GatedMLP::run_bsz1_gr
|
||||
at::Tensor u = gu.select(0, 1).unsqueeze(0);
|
||||
|
||||
if (act_silu)
|
||||
silu_mul_gr(g, u, a, graph);
|
||||
silu_mul_gr(g, u, a, act_limit, graph);
|
||||
else if (act_gelu)
|
||||
gelu_mul_gr(g, u, a, graph);
|
||||
gelu_mul_gr(g, u, a, act_limit, graph);
|
||||
else if (act_relu2)
|
||||
relu2_mul_gr(g, u, a, graph);
|
||||
relu2_mul_gr(g, u, a, act_limit, graph);
|
||||
|
||||
down->run_gr(a, d, graph);
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ struct BC_GatedMLP
|
||||
bool act_gelu;
|
||||
bool act_relu2;
|
||||
std::shared_ptr<BC_LinearEXL3> down;
|
||||
float act_limit;
|
||||
|
||||
Graph graph_bsz1;
|
||||
|
||||
@@ -40,7 +41,8 @@ struct BC_GatedMLP
|
||||
bool _act_silu,
|
||||
bool _act_gelu,
|
||||
bool _act_relu2,
|
||||
std::shared_ptr<BC_LinearEXL3> _down
|
||||
std::shared_ptr<BC_LinearEXL3> _down,
|
||||
float _act_limit
|
||||
) :
|
||||
guh (std::move(_guh)),
|
||||
gu (std::move(_gu)),
|
||||
@@ -54,7 +56,8 @@ struct BC_GatedMLP
|
||||
act_silu (_act_silu),
|
||||
act_gelu (_act_gelu),
|
||||
act_relu2 (_act_relu2),
|
||||
down (_down)
|
||||
down (_down),
|
||||
act_limit (_act_limit)
|
||||
{}
|
||||
|
||||
void run_bsz1_gr
|
||||
|
||||
@@ -13,7 +13,8 @@ py::class_<BC_GatedMLP, std::shared_ptr<BC_GatedMLP>>(m, "BC_GatedMLP").def
|
||||
bool,
|
||||
bool,
|
||||
bool,
|
||||
std::shared_ptr<BC_LinearEXL3>
|
||||
std::shared_ptr<BC_LinearEXL3>,
|
||||
float
|
||||
>(),
|
||||
py::arg("guh"),
|
||||
py::arg("gu"),
|
||||
@@ -27,6 +28,7 @@ py::class_<BC_GatedMLP, std::shared_ptr<BC_GatedMLP>>(m, "BC_GatedMLP").def
|
||||
py::arg("act_silu"),
|
||||
py::arg("act_gelu"),
|
||||
py::arg("act_relu2"),
|
||||
py::arg("down")
|
||||
py::arg("down"),
|
||||
py::arg("act_limit")
|
||||
)
|
||||
.def("run_bsz1", &BC_GatedMLP::run_bsz1);
|
||||
|
||||
@@ -169,6 +169,7 @@ class BlockSparseMLP(Module):
|
||||
qmap: str | None = None,
|
||||
out_dtype: torch.dtype = None,
|
||||
activation_fn: str = "silu",
|
||||
act_limit: float = 0.0,
|
||||
interm_dtype: torch.dtype = None,
|
||||
router_type: str = "std",
|
||||
routing_gate: Linear | None = None,
|
||||
@@ -197,6 +198,7 @@ class BlockSparseMLP(Module):
|
||||
self.num_local_experts = num_local_experts if num_local_experts is not None else num_experts
|
||||
self.hidden_size = hidden_size
|
||||
self.router_type = router_type
|
||||
self.act_limit = act_limit
|
||||
|
||||
self.routing_first = routing_first
|
||||
self.routing_last = routing_last
|
||||
@@ -458,7 +460,8 @@ class BlockSparseMLP(Module):
|
||||
self.activation_fn == "silu",
|
||||
self.activation_fn == "gelu",
|
||||
sh_exp,
|
||||
sh_gate
|
||||
sh_gate,
|
||||
self.act_limit
|
||||
)
|
||||
|
||||
|
||||
@@ -566,7 +569,7 @@ class BlockSparseMLP(Module):
|
||||
g = self.gates[exp_i].forward(xc, params)
|
||||
u = self.ups[exp_i].forward(xc, params)
|
||||
a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
|
||||
self.activation_fn_call(g, u, a)
|
||||
self.activation_fn_call(g, u, a, self.act_limit)
|
||||
return self.downs[exp_i].forward(a, params)
|
||||
|
||||
for expert_idx in range(num_ex):
|
||||
@@ -642,7 +645,7 @@ class BlockSparseMLP(Module):
|
||||
)
|
||||
|
||||
# Activation
|
||||
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a)
|
||||
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a, self.act_limit)
|
||||
|
||||
# Down
|
||||
ext.exl3_mgemm(
|
||||
@@ -717,7 +720,7 @@ class BlockSparseMLP(Module):
|
||||
)
|
||||
|
||||
# Activation
|
||||
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a)
|
||||
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a, self.act_limit)
|
||||
|
||||
# Down
|
||||
ext.exl3_mgemm(
|
||||
|
||||
@@ -374,6 +374,7 @@ class GatedMLP(Module):
|
||||
activation_fn: str = "silu",
|
||||
intermediate_split_size: int | None = MAX_MLP_INTERMEDIATE,
|
||||
interm_dtype: torch.dtype = None,
|
||||
act_limit: float = 0.0,
|
||||
pad_to = 128,
|
||||
gates: list[Linear | Module] = None,
|
||||
ups: list[Linear | Module] = None,
|
||||
@@ -388,6 +389,7 @@ class GatedMLP(Module):
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_split_size = intermediate_split_size
|
||||
self.pad_to = pad_to
|
||||
self.act_limit = act_limit
|
||||
|
||||
if key_fused_gate_up:
|
||||
assert not intermediate_split_size or intermediate_size <= intermediate_split_size, \
|
||||
@@ -573,6 +575,7 @@ class GatedMLP(Module):
|
||||
self.activation_fn == "gelu",
|
||||
self.activation_fn == "relu2",
|
||||
self.downs[0].inner.bc,
|
||||
self.act_limit,
|
||||
)
|
||||
|
||||
|
||||
@@ -616,7 +619,7 @@ class GatedMLP(Module):
|
||||
g = self.gates[s].forward(x, params)
|
||||
u = self.ups[s].forward(x, params)
|
||||
a = torch.empty_like(u, dtype = torch.half) if self.interm_dtype != torch.half else u
|
||||
self.activation_fn_call(g, u, a)
|
||||
self.activation_fn_call(g, u, a, self.act_limit)
|
||||
d_ = self.downs[s].forward(a, params)
|
||||
|
||||
if d is None: d = d_
|
||||
@@ -653,7 +656,7 @@ class GatedMLP(Module):
|
||||
u = gu[1].view(bsz, q_len, self.multi_gu[s].out_features)
|
||||
|
||||
a = torch.empty_like(u, dtype = torch.half) if self.interm_dtype != torch.half else u
|
||||
self.activation_fn_call(g, u, a)
|
||||
self.activation_fn_call(g, u, a, self.act_limit)
|
||||
d_ = self.downs[s].forward(a, params)
|
||||
|
||||
if d is None: d = d_
|
||||
|
||||
Reference in New Issue
Block a user