mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-03-15 00:07:24 +00:00
Addition kernel
This commit is contained in:
99
exllamav3/exllamav3_ext/add.cu
Normal file
99
exllamav3/exllamav3_ext/add.cu
Normal file
@@ -0,0 +1,99 @@
|
||||
#include <cuda_fp16.h>

#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/util/Exception.h>

#include "add.cuh"
#include "util.h"
#include "util.cuh"
|
||||
|
||||
#define NUM_THREADS 1024
|
||||
|
||||
#define KERNEL_DEF(xt, yt, zt, kernel, fn) \
|
||||
__launch_bounds__(NUM_THREADS) \
|
||||
__global__ void kernel \
|
||||
( \
|
||||
const xt* __restrict__ x, \
|
||||
const yt* __restrict__ y, \
|
||||
zt* __restrict__ z, \
|
||||
const uint64_t numel \
|
||||
) \
|
||||
{ \
|
||||
uint64_t idx = ((uint64_t)blockIdx.x * NUM_THREADS + (uint64_t)threadIdx.x); \
|
||||
if (idx >= numel) return; \
|
||||
xt a = x[idx]; \
|
||||
yt b = y[idx]; \
|
||||
z[idx] = fn; \
|
||||
}
|
||||
|
||||
KERNEL_DEF(half, half, half, add_kernel_hhh, __hadd(a, b))
|
||||
KERNEL_DEF(half, half, float, add_kernel_hhf, __half2float(__hadd(a, b)))
|
||||
KERNEL_DEF(half, float, half, add_kernel_hfh, __float2half_rn(__half2float(a) + b))
|
||||
KERNEL_DEF(half, float, float, add_kernel_hff, __half2float(a) + b)
|
||||
KERNEL_DEF(float, half, half, add_kernel_fhh, __float2half_rn(a + __half2float(b)))
|
||||
KERNEL_DEF(float, half, float, add_kernel_fhf, a + __half2float(b))
|
||||
KERNEL_DEF(float, float, half, add_kernel_ffh, __float2half_rn(a + b))
|
||||
KERNEL_DEF(float, float, float, add_kernel_fff, a + b)
|
||||
|
||||
#undef KERNEL_DEF
|
||||
|
||||
/*
|
||||
x + y -> z
|
||||
Works inplace if x == z or y == z
|
||||
*/
|
||||
|
||||
void add_gr
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
Graph* graph
|
||||
)
|
||||
{
|
||||
const at::cuda::OptionalCUDAGuard device_guard(x.device());
|
||||
cudaStream_t stream = graph ? graph->capture_stream : at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
auto xt = x.dtype();
|
||||
auto yt = y.dtype();
|
||||
auto zt = z.dtype();
|
||||
uint64_t numel = x.numel();
|
||||
int blocks = (int) CEIL_DIVIDE(numel, (uint64_t) NUM_THREADS);
|
||||
|
||||
#define INSTANCE(xt_, yt_, zt_, xt__, yt__, zt__, kernel) \
|
||||
if (xt == xt_ && yt == yt_ && zt == zt_) \
|
||||
{ \
|
||||
kernel<<<blocks, NUM_THREADS, 0, stream>>> \
|
||||
( \
|
||||
(const xt__*) x.data_ptr(), \
|
||||
(const yt__*) y.data_ptr(), \
|
||||
(zt__*) z.data_ptr(), \
|
||||
numel \
|
||||
); \
|
||||
if (graph) graph->record_param((void*) &kernel, GP_add_x, 0); \
|
||||
if (graph) graph->record_param((void*) &kernel, GP_add_y, 1); \
|
||||
if (graph) graph->record_param((void*) &kernel, GP_add_z, 2); \
|
||||
if (graph) graph->record_param((void*) &kernel, GP_end, 0); \
|
||||
cuda_check(cudaPeekAtLastError()); \
|
||||
}
|
||||
|
||||
INSTANCE(at::kHalf, at::kHalf, at::kHalf, half, half, half , add_kernel_hhh)
|
||||
INSTANCE(at::kHalf, at::kHalf, at::kFloat, half, half, float, add_kernel_hhf)
|
||||
INSTANCE(at::kHalf, at::kFloat, at::kHalf, half, float, half , add_kernel_hfh)
|
||||
INSTANCE(at::kHalf, at::kFloat, at::kFloat, half, float, float, add_kernel_hff)
|
||||
INSTANCE(at::kFloat, at::kHalf, at::kHalf, float, half, half , add_kernel_fhh)
|
||||
INSTANCE(at::kFloat, at::kHalf, at::kFloat, float, half, float, add_kernel_fhf)
|
||||
INSTANCE(at::kFloat, at::kFloat, at::kHalf, float, float, half , add_kernel_ffh)
|
||||
INSTANCE(at::kFloat, at::kFloat, at::kFloat, float, float, float, add_kernel_fff)
|
||||
|
||||
#undef INSTANCE
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
void add
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
)
|
||||
{
|
||||
add_gr(x, y, z, nullptr);
|
||||
}
|
||||
19
exllamav3/exllamav3_ext/add.cuh
Normal file
19
exllamav3/exllamav3_ext/add.cuh
Normal file
@@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include "graph.cuh"
|
||||
|
||||
void add_gr
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z,
|
||||
Graph* graph
|
||||
);
|
||||
|
||||
void add
|
||||
(
|
||||
const at::Tensor& x,
|
||||
const at::Tensor& y,
|
||||
at::Tensor& z
|
||||
);
|
||||
@@ -15,6 +15,7 @@
|
||||
#include "routing.cuh"
|
||||
#include "gdn.cuh"
|
||||
#include "causal_conv1d.cuh"
|
||||
#include "add.cuh"
|
||||
|
||||
#include "quant/quantize.cuh"
|
||||
#include "quant/pack.cuh"
|
||||
@@ -98,6 +99,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
|
||||
m.def("relu2_mul", &relu2_mul, "relu2_mul");
|
||||
m.def("xielu", &xielu, "xielu");
|
||||
m.def("add_sigmoid_gate", &add_sigmoid_gate, "add_sigmoid_gate");
|
||||
m.def("add", &add, "add");
|
||||
|
||||
m.def("gated_delta_net_fused_op", &gated_delta_net_fused_op, "gated_delta_net_fused_op");
|
||||
m.def("cuda_recurrent_gated_delta_rule", &cuda_recurrent_gated_delta_rule, "cuda_recurrent_gated_delta_rule");
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
#include "../hgemm.cuh"
|
||||
#include "../quant/exl3_gemm.cuh"
|
||||
#include "../activation.cuh"
|
||||
#include "../add.cuh"
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> blocksparse_mlp_routing(
|
||||
int bsz,
|
||||
@@ -140,7 +141,7 @@ void BC_BlockSparseMLP::run_bsz1
|
||||
}
|
||||
else
|
||||
{
|
||||
out_d.add_(out_d_sh.value());
|
||||
add(out_d, out_d_sh.value(), out_d);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user