Files
exllamav3/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp_bc.h
2026-03-01 17:57:55 +01:00

71 lines
1.6 KiB
C++

py::class_<BC_BlockSparseMLP, std::shared_ptr<BC_BlockSparseMLP>>(m, "BC_BlockSparseMLP").def
(
py::init<
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
at::Tensor,
c10::optional<at::Tensor>,
c10::optional<at::Tensor>,
int,
int,
at::Tensor,
at::Tensor,
at::Tensor,
int,
bool,
bool,
at::Tensor,
at::Tensor,
at::Tensor,
int,
bool,
bool,
at::Tensor,
at::Tensor,
at::Tensor,
int,
bool,
bool,
bool,
bool,
std::shared_ptr<BC_GatedMLP>,
std::shared_ptr<BC_LinearFP16>,
float
>(),
py::arg("yh"),
py::arg("interm_g"),
py::arg("interm_u"),
py::arg("interm_a"),
py::arg("out_d"),
py::arg("out_d_sh"),
py::arg("z"),
py::arg("min_expert"),
py::arg("max_expert"),
py::arg("gate_ptrs_trellis"),
py::arg("gate_ptrs_suh"),
py::arg("gate_ptrs_svh"),
py::arg("gate_K"),
py::arg("gate_mcg"),
py::arg("gate_mul1"),
py::arg("up_ptrs_trellis"),
py::arg("up_ptrs_suh"),
py::arg("up_ptrs_svh"),
py::arg("up_K"),
py::arg("up_mcg"),
py::arg("up_mul1"),
py::arg("down_ptrs_trellis"),
py::arg("down_ptrs_suh"),
py::arg("down_ptrs_svh"),
py::arg("down_K"),
py::arg("down_mcg"),
py::arg("down_mul1"),
py::arg("act_silu"),
py::arg("act_gelu"),
py::arg("shared_experts"),
py::arg("shared_gate"),
py::arg("act_limit")
)
.def("run_bsz1", &BC_BlockSparseMLP::run_bsz1);