Add custom activation limits

This commit is contained in:
turboderp
2026-03-01 15:14:02 +01:00
parent b272ea3515
commit 99f792dce0
11 changed files with 81 additions and 28 deletions

View File

@@ -24,6 +24,7 @@ void silu_mul_gr
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z,
const float act_limit,
Graph* graph
)
{
@@ -52,6 +53,7 @@ void silu_mul_gr
(const float*) x.data_ptr(),
(const float*) y.data_ptr(),
(half*) z.data_ptr(),
act_limit,
numel
);
@@ -69,6 +71,7 @@ void silu_mul_gr
(const half*) x.data_ptr(),
(const half*) y.data_ptr(),
(half*) z.data_ptr(),
act_limit,
numel
);
@@ -85,10 +88,11 @@ void silu_mul
(
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z
at::Tensor& z,
const float act_limit
)
{
silu_mul_gr(x, y, z, nullptr);
silu_mul_gr(x, y, z, act_limit, nullptr);
}
// gelu(x) * y -> z, in-place if z == x or z == y
@@ -98,6 +102,7 @@ void gelu_mul_gr
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z,
const float act_limit,
Graph* graph
)
{
@@ -126,6 +131,7 @@ void gelu_mul_gr
(const float*) x.data_ptr(),
(const float*) y.data_ptr(),
(half*) z.data_ptr(),
act_limit,
numel
);
@@ -143,6 +149,7 @@ void gelu_mul_gr
(const half*) x.data_ptr(),
(const half*) y.data_ptr(),
(half*) z.data_ptr(),
act_limit,
numel
);
@@ -159,10 +166,11 @@ void gelu_mul
(
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z
at::Tensor& z,
const float act_limit
)
{
gelu_mul_gr(x, y, z, nullptr);
gelu_mul_gr(x, y, z, act_limit, nullptr);
}
// relu^2(x) * y -> z
@@ -172,6 +180,7 @@ void relu2_mul_gr
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z,
const float act_limit,
Graph* graph
)
{
@@ -200,6 +209,7 @@ void relu2_mul_gr
(const float*) x.data_ptr(),
(const float*) y.data_ptr(),
(half*) z.data_ptr(),
act_limit,
numel
);
@@ -217,6 +227,7 @@ void relu2_mul_gr
(const half*) x.data_ptr(),
(const half*) y.data_ptr(),
(half*) z.data_ptr(),
act_limit,
numel
);
@@ -233,10 +244,11 @@ void relu2_mul
(
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z
at::Tensor& z,
const float act_limit
)
{
relu2_mul_gr(x, y, z, nullptr);
relu2_mul_gr(x, y, z, act_limit, nullptr);
}
// xielu(x, alpha_p, alpha_n) -> z

View File

@@ -8,6 +8,7 @@ void silu_mul_gr
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z,
const float act_limit,
Graph* graph
);
@@ -15,7 +16,8 @@ void silu_mul
(
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z
at::Tensor& z,
const float act_limit
);
void gelu_mul_gr
@@ -23,6 +25,7 @@ void gelu_mul_gr
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z,
const float act_limit,
Graph* graph
);
@@ -30,7 +33,8 @@ void gelu_mul
(
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z
at::Tensor& z,
const float act_limit
);
void relu2_mul_gr
@@ -38,6 +42,7 @@ void relu2_mul_gr
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z,
const float act_limit,
Graph* graph
);
@@ -45,7 +50,8 @@ void relu2_mul
(
const at::Tensor& x,
const at::Tensor& y,
at::Tensor& z
at::Tensor& z,
const float act_limit
);
void xielu_gr

View File

@@ -110,6 +110,7 @@ void act_mul_kernel_h
const half* __restrict__ x,
const half* __restrict__ y,
half* __restrict__ z,
const float act_limit,
const size_t numel
)
{
@@ -126,6 +127,13 @@ void act_mul_kernel_h
else if constexpr (activation_type == ACT_RELU2)
x2 = _relu2(x2);
if (act_limit != 0.0f)
{
x2 = __hmax2(x2, __float2half2_rn(-act_limit));
x2 = __hmin2(x2, __float2half2_rn(act_limit));
y2 = __hmin2(y2, __float2half2_rn(act_limit));
}
((half2*) z)[idx] = __hmul2(x2, y2);
}
@@ -137,6 +145,7 @@ void act_mul_kernel_f
const float* __restrict__ x,
const float* __restrict__ y,
half* __restrict__ z,
const float act_limit,
const size_t numel
)
{
@@ -162,6 +171,16 @@ void act_mul_kernel_f
x2.y = _relu2(x2.y);
}
if (act_limit != 0.0f)
{
if (x2.x < -act_limit) x2.x = -act_limit;
if (x2.y < -act_limit) x2.y = -act_limit;
if (x2.x > act_limit) x2.x = act_limit;
if (x2.y > act_limit) x2.y = act_limit;
if (y2.x > act_limit) y2.x = act_limit;
if (y2.y > act_limit) y2.y = act_limit;
}
x2.x *= y2.x;
x2.y *= y2.y;
half2 r = __float22half2_rn(x2);

View File

@@ -113,9 +113,9 @@ void BC_BlockSparseMLP::run_bsz1_gr
);
if (act_silu)
silu_mul_gr(interm_g, interm_u, interm_a, graph);
silu_mul_gr(interm_g, interm_u, interm_a, act_limit, graph);
else if (act_gelu)
gelu_mul_gr(interm_g, interm_u, interm_a, graph);
gelu_mul_gr(interm_g, interm_u, interm_a, act_limit, graph);
exl3_mgemm_gr
(

View File

@@ -49,6 +49,7 @@ struct BC_BlockSparseMLP
bool act_gelu;
std::shared_ptr<BC_GatedMLP> shared_experts;
std::shared_ptr<BC_LinearFP16> shared_gate;
float act_limit;
Graph graph_bsz1;
@@ -84,7 +85,8 @@ struct BC_BlockSparseMLP
bool _act_silu,
bool _act_gelu,
std::shared_ptr<BC_GatedMLP> _shared_experts,
std::shared_ptr<BC_LinearFP16> _shared_gate
std::shared_ptr<BC_LinearFP16> _shared_gate,
float _act_limit
) :
yh (std::move(_yh)),
interm_g (std::move(_interm_g)),
@@ -116,7 +118,8 @@ struct BC_BlockSparseMLP
act_silu (_act_silu),
act_gelu (_act_gelu),
shared_experts (_shared_experts),
shared_gate (_shared_gate)
shared_gate (_shared_gate),
act_limit (_act_limit)
{}
void run_bsz1_gr

View File

@@ -31,7 +31,8 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
bool,
bool,
std::shared_ptr<BC_GatedMLP>,
std::shared_ptr<BC_LinearFP16>
std::shared_ptr<BC_LinearFP16>,
float
>(),
py::arg("yh"),
py::arg("interm_g"),
@@ -63,6 +64,7 @@ py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSp
py::arg("act_silu"),
py::arg("act_gelu"),
py::arg("shared_experts"),
py::arg("shared_gate")
py::arg("shared_gate"),
py::arg("act_limit")
)
.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1);

View File

@@ -41,11 +41,11 @@ void BC_GatedMLP::run_bsz1_gr
at::Tensor u = gu.select(0, 1).unsqueeze(0);
if (act_silu)
silu_mul_gr(g, u, a, graph);
silu_mul_gr(g, u, a, act_limit, graph);
else if (act_gelu)
gelu_mul_gr(g, u, a, graph);
gelu_mul_gr(g, u, a, act_limit, graph);
else if (act_relu2)
relu2_mul_gr(g, u, a, graph);
relu2_mul_gr(g, u, a, act_limit, graph);
down->run_gr(a, d, graph);
}

View File

@@ -23,6 +23,7 @@ struct BC_GatedMLP
bool act_gelu;
bool act_relu2;
std::shared_ptr<BC_LinearEXL3> down;
float act_limit;
Graph graph_bsz1;
@@ -40,7 +41,8 @@ struct BC_GatedMLP
bool _act_silu,
bool _act_gelu,
bool _act_relu2,
std::shared_ptr<BC_LinearEXL3> _down
std::shared_ptr<BC_LinearEXL3> _down,
float _act_limit
) :
guh (std::move(_guh)),
gu (std::move(_gu)),
@@ -54,7 +56,8 @@ struct BC_GatedMLP
act_silu (_act_silu),
act_gelu (_act_gelu),
act_relu2 (_act_relu2),
down (_down)
down (_down),
act_limit (_act_limit)
{}
void run_bsz1_gr

View File

@@ -13,7 +13,8 @@ py::class_<BC_GatedMLP, std::shared_ptr<BC_GatedMLP>>(m, "BC_GatedMLP").def
bool,
bool,
bool,
std::shared_ptr<BC_LinearEXL3>
std::shared_ptr<BC_LinearEXL3>,
float
>(),
py::arg("guh"),
py::arg("gu"),
@@ -27,6 +28,7 @@ py::class_<BC_GatedMLP, std::shared_ptr<BC_GatedMLP>>(m, "BC_GatedMLP").def
py::arg("act_silu"),
py::arg("act_gelu"),
py::arg("act_relu2"),
py::arg("down")
py::arg("down"),
py::arg("act_limit")
)
.def("run_bsz1", &BC_GatedMLP::run_bsz1);

View File

@@ -169,6 +169,7 @@ class BlockSparseMLP(Module):
qmap: str | None = None,
out_dtype: torch.dtype = None,
activation_fn: str = "silu",
act_limit: float = 0.0,
interm_dtype: torch.dtype = None,
router_type: str = "std",
routing_gate: Linear | None = None,
@@ -197,6 +198,7 @@ class BlockSparseMLP(Module):
self.num_local_experts = num_local_experts if num_local_experts is not None else num_experts
self.hidden_size = hidden_size
self.router_type = router_type
self.act_limit = act_limit
self.routing_first = routing_first
self.routing_last = routing_last
@@ -458,7 +460,8 @@ class BlockSparseMLP(Module):
self.activation_fn == "silu",
self.activation_fn == "gelu",
sh_exp,
sh_gate
sh_gate,
self.act_limit
)
@@ -566,7 +569,7 @@ class BlockSparseMLP(Module):
g = self.gates[exp_i].forward(xc, params)
u = self.ups[exp_i].forward(xc, params)
a = u if self.interm_dtype == torch.half else torch.empty_like(u, dtype = torch.half)
self.activation_fn_call(g, u, a)
self.activation_fn_call(g, u, a, self.act_limit)
return self.downs[exp_i].forward(a, params)
for expert_idx in range(num_ex):
@@ -642,7 +645,7 @@ class BlockSparseMLP(Module):
)
# Activation
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a)
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a, self.act_limit)
# Down
ext.exl3_mgemm(
@@ -717,7 +720,7 @@ class BlockSparseMLP(Module):
)
# Activation
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a)
self.activation_fn_call(cfg.interm_g, cfg.interm_u, cfg.interm_a, self.act_limit)
# Down
ext.exl3_mgemm(

View File

@@ -374,6 +374,7 @@ class GatedMLP(Module):
activation_fn: str = "silu",
intermediate_split_size: int | None = MAX_MLP_INTERMEDIATE,
interm_dtype: torch.dtype = None,
act_limit: float = 0.0,
pad_to = 128,
gates: list[Linear | Module] = None,
ups: list[Linear | Module] = None,
@@ -388,6 +389,7 @@ class GatedMLP(Module):
self.hidden_size = hidden_size
self.intermediate_split_size = intermediate_split_size
self.pad_to = pad_to
self.act_limit = act_limit
if key_fused_gate_up:
assert not intermediate_split_size or intermediate_size <= intermediate_split_size, \
@@ -573,6 +575,7 @@ class GatedMLP(Module):
self.activation_fn == "gelu",
self.activation_fn == "relu2",
self.downs[0].inner.bc,
self.act_limit,
)
@@ -616,7 +619,7 @@ class GatedMLP(Module):
g = self.gates[s].forward(x, params)
u = self.ups[s].forward(x, params)
a = torch.empty_like(u, dtype = torch.half) if self.interm_dtype != torch.half else u
self.activation_fn_call(g, u, a)
self.activation_fn_call(g, u, a, self.act_limit)
d_ = self.downs[s].forward(a, params)
if d is None: d = d_
@@ -653,7 +656,7 @@ class GatedMLP(Module):
u = gu[1].view(bsz, q_len, self.multi_gu[s].out_features)
a = torch.empty_like(u, dtype = torch.half) if self.interm_dtype != torch.half else u
self.activation_fn_call(g, u, a)
self.activation_fn_call(g, u, a, self.act_limit)
d_ = self.downs[s].forward(a, params)
if d is None: d = d_