diff --git a/exllamav3/exllamav3_ext/add.cu b/exllamav3/exllamav3_ext/add.cu
new file mode 100644
index 0000000..94e5754
--- /dev/null
+++ b/exllamav3/exllamav3_ext/add.cu
@@ -0,0 +1,105 @@
+#include <torch/all.h>
+#include "add.cuh"
+#include <c10/cuda/CUDAGuard.h>
+#include <ATen/cuda/CUDAContext.h>
+#include "util.h"
+#include "util.cuh"
+
+#define NUM_THREADS 1024
+
+// Elementwise z = x + y for one (x, y, z) dtype combination. One thread per
+// element; the idx >= numel guard handles the partial tail block.
+#define KERNEL_DEF(xt, yt, zt, kernel, fn) \
+__launch_bounds__(NUM_THREADS) \
+__global__ void kernel \
+( \
+    const xt* __restrict__ x, \
+    const yt* __restrict__ y, \
+    zt* __restrict__ z, \
+    const uint64_t numel \
+) \
+{ \
+    uint64_t idx = ((uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x); \
+    if (idx >= numel) return; \
+    xt a = x[idx]; \
+    yt b = y[idx]; \
+    z[idx] = fn; \
+}
+
+KERNEL_DEF(half, half, half, add_kernel_hhh, __hadd(a, b))
+KERNEL_DEF(half, half, float, add_kernel_hhf, __half2float(__hadd(a, b)))
+KERNEL_DEF(half, float, half, add_kernel_hfh, __float2half_rn(__half2float(a) + b))
+KERNEL_DEF(half, float, float, add_kernel_hff, __half2float(a) + b)
+KERNEL_DEF(float, half, half, add_kernel_fhh, __float2half_rn(a + __half2float(b)))
+KERNEL_DEF(float, half, float, add_kernel_fhf, a + __half2float(b))
+KERNEL_DEF(float, float, half, add_kernel_ffh, __float2half_rn(a + b))
+KERNEL_DEF(float, float, float, add_kernel_fff, a + b)
+
+#undef KERNEL_DEF
+
+/*
+x + y -> z
+Works inplace if x == z or y == z
+Flat indexing over data_ptr() assumes contiguous tensors of equal numel;
+supported dtypes are half and float (any other combination is silently a no-op).
+*/
+
+void add_gr
+(
+    const at::Tensor& x,
+    const at::Tensor& y,
+    at::Tensor& z,
+    Graph* graph
+)
+{
+    const at::cuda::OptionalCUDAGuard device_guard(x.device());
+    cudaStream_t stream = graph ? graph->capture_stream : at::cuda::getCurrentCUDAStream().stream();
+
+    auto xt = x.dtype();
+    auto yt = y.dtype();
+    auto zt = z.dtype();
+    uint64_t numel = x.numel();
+    if (!numel) return;  // empty tensor: a zero-block launch would be an invalid configuration
+    int blocks = (int) CEIL_DIVIDE(numel, (uint64_t) NUM_THREADS);
+
+    #define INSTANCE(xt_, yt_, zt_, xt__, yt__, zt__, kernel) \
+    if (xt == xt_ && yt == yt_ && zt == zt_) \
+    { \
+        kernel<<<blocks, NUM_THREADS, 0, stream>>> \
+        ( \
+            (const xt__*) x.data_ptr(), \
+            (const yt__*) y.data_ptr(), \
+            (zt__*) z.data_ptr(), \
+            numel \
+        ); \
+        if (graph) graph->record_param((void*) &kernel, GP_add_x, 0); \
+        if (graph) graph->record_param((void*) &kernel, GP_add_y, 1); \
+        if (graph) graph->record_param((void*) &kernel, GP_add_z, 2); \
+        if (graph) graph->record_param((void*) &kernel, GP_end, 0); \
+        cuda_check(cudaPeekAtLastError()); \
+    }
+
+    INSTANCE(at::kHalf, at::kHalf, at::kHalf, half, half, half , add_kernel_hhh)
+    INSTANCE(at::kHalf, at::kHalf, at::kFloat, half, half, float, add_kernel_hhf)
+    INSTANCE(at::kHalf, at::kFloat, at::kHalf, half, float, half , add_kernel_hfh)
+    INSTANCE(at::kHalf, at::kFloat, at::kFloat, half, float, float, add_kernel_hff)
+    INSTANCE(at::kFloat, at::kHalf, at::kHalf, float, half, half , add_kernel_fhh)
+    INSTANCE(at::kFloat, at::kHalf, at::kFloat, float, half, float, add_kernel_fhf)
+    INSTANCE(at::kFloat, at::kFloat, at::kHalf, float, float, half , add_kernel_ffh)
+    INSTANCE(at::kFloat, at::kFloat, at::kFloat, float, float, float, add_kernel_fff)
+
+    #undef INSTANCE
+
+    cuda_check(cudaPeekAtLastError());
+}
+
+// Eager (non-graph) entry point: same contract as add_gr.
+void add
+(
+    const at::Tensor& x,
+    const at::Tensor& y,
+    at::Tensor& z
+)
+{
+    add_gr(x, y, z, nullptr);
+}
diff --git a/exllamav3/exllamav3_ext/add.cuh b/exllamav3/exllamav3_ext/add.cuh
new file mode 100644
index 0000000..6a6913b
--- /dev/null
+++ b/exllamav3/exllamav3_ext/add.cuh
@@ -0,0 +1,23 @@
+#pragma once
+
+#include <torch/all.h>
+#include "graph.cuh"
+
+// Elementwise z = x + y; works in place when x == z or y == z. Assumes
+// contiguous tensors of equal numel; half/float dtypes only. When graph is
+// non-null, the launch is recorded on graph->capture_stream for CUDA graph
+// capture; otherwise the current Torch stream is used.
+void add_gr
+(
+    const at::Tensor& x,
+    const at::Tensor& y,
+    at::Tensor& z,
+    Graph* graph
+);
+
+// Eager wrapper: add_gr with no graph capture.
+void add
+(
+    const at::Tensor& x,
+    const at::Tensor& y,
+    at::Tensor& z
+);
diff --git a/exllamav3/exllamav3_ext/bindings.cpp b/exllamav3/exllamav3_ext/bindings.cpp
index 4ee20b7..4c68f1d 100644
--- a/exllamav3/exllamav3_ext/bindings.cpp
+++ b/exllamav3/exllamav3_ext/bindings.cpp
@@ -15,3 +15,4 @@
 #include "routing.cuh"
 #include "gdn.cuh"
 #include "causal_conv1d.cuh"
+#include "add.cuh"
@@ -98,3 +99,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
     m.def("relu2_mul", &relu2_mul, "relu2_mul");
     m.def("xielu", &xielu, "xielu");
     m.def("add_sigmoid_gate", &add_sigmoid_gate, "add_sigmoid_gate");
+    m.def("add", &add, "add");
diff --git a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp
index 5ccbf79..e0e871b 100644
--- a/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp
+++ b/exllamav3/exllamav3_ext/libtorch/blocksparse_mlp.cpp
@@ -7,3 +7,4 @@
 #include "../hgemm.cuh"
 #include "../quant/exl3_gemm.cuh"
 #include "../activation.cuh"
+#include "../add.cuh"
@@ -140,5 +141,5 @@ void BC_BlockSparseMLP::run_bsz1
     }
     else
     {
-        out_d.add_(out_d_sh.value());
+        add(out_d, out_d_sh.value(), out_d);
     }