mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
Alias __nv_bfloat16 -> bfloat16
This commit is contained in:
@@ -6,6 +6,7 @@
|
||||
#include "util.cuh"
|
||||
|
||||
#define NUM_THREADS 1024
|
||||
using bfloat16 = __nv_bfloat16;
|
||||
|
||||
template <int num_threads>
|
||||
__device__ inline float reduce(float sum, int warp_id, int lane_id)
|
||||
@@ -298,10 +299,10 @@ template <int num_threads, typename output_t>
|
||||
__global__ __launch_bounds__(num_threads)
|
||||
void gated_rms_norm_kernel
|
||||
(
|
||||
const __nv_bfloat16* __restrict__ x,
|
||||
const __nv_bfloat16* __restrict__ w,
|
||||
const bfloat16* __restrict__ x,
|
||||
const bfloat16* __restrict__ w,
|
||||
output_t* __restrict__ y,
|
||||
const __nv_bfloat16* __restrict__ g,
|
||||
const bfloat16* __restrict__ g,
|
||||
const float epsilon,
|
||||
const int rows,
|
||||
const int dim,
|
||||
@@ -403,10 +404,10 @@ void gated_rms_norm
|
||||
if (!small && output_fp16)
|
||||
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const __nv_bfloat16*) x.data_ptr(),
|
||||
(const __nv_bfloat16*) w.data_ptr(),
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const bfloat16*) w.data_ptr(),
|
||||
(half*) y.data_ptr(),
|
||||
(const __nv_bfloat16*) g.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
@@ -415,10 +416,10 @@ void gated_rms_norm
|
||||
else if (!small && output_fp32)
|
||||
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const __nv_bfloat16*) x.data_ptr(),
|
||||
(const __nv_bfloat16*) w.data_ptr(),
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const bfloat16*) w.data_ptr(),
|
||||
(float*) y.data_ptr(),
|
||||
(const __nv_bfloat16*) g.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
@@ -427,10 +428,10 @@ void gated_rms_norm
|
||||
else if (small && output_fp16)
|
||||
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const __nv_bfloat16*) x.data_ptr(),
|
||||
(const __nv_bfloat16*) w.data_ptr(),
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const bfloat16*) w.data_ptr(),
|
||||
(half*) y.data_ptr(),
|
||||
(const __nv_bfloat16*) g.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
@@ -439,10 +440,10 @@ void gated_rms_norm
|
||||
else if (small && output_fp32)
|
||||
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const __nv_bfloat16*) x.data_ptr(),
|
||||
(const __nv_bfloat16*) w.data_ptr(),
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const bfloat16*) w.data_ptr(),
|
||||
(float*) y.data_ptr(),
|
||||
(const __nv_bfloat16*) g.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
|
||||
Reference in New Issue
Block a user