Alias __nv_bfloat16 -> bfloat16

This commit is contained in:
turboderp
2026-02-17 21:24:41 +01:00
parent b2b6f37e12
commit ed5bad7235

View File

@@ -6,6 +6,7 @@
#include "util.cuh"
#define NUM_THREADS 1024
using bfloat16 = __nv_bfloat16;
template <int num_threads>
__device__ inline float reduce(float sum, int warp_id, int lane_id)
@@ -298,10 +299,10 @@ template <int num_threads, typename output_t>
__global__ __launch_bounds__(num_threads)
void gated_rms_norm_kernel
(
const __nv_bfloat16* __restrict__ x,
const __nv_bfloat16* __restrict__ w,
const bfloat16* __restrict__ x,
const bfloat16* __restrict__ w,
output_t* __restrict__ y,
const __nv_bfloat16* __restrict__ g,
const bfloat16* __restrict__ g,
const float epsilon,
const int rows,
const int dim,
@@ -403,10 +404,10 @@ void gated_rms_norm
if (!small && output_fp16)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const __nv_bfloat16*) x.data_ptr(),
(const __nv_bfloat16*) w.data_ptr(),
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(half*) y.data_ptr(),
(const __nv_bfloat16*) g.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
@@ -415,10 +416,10 @@ void gated_rms_norm
else if (!small && output_fp32)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const __nv_bfloat16*) x.data_ptr(),
(const __nv_bfloat16*) w.data_ptr(),
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(float*) y.data_ptr(),
(const __nv_bfloat16*) g.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
@@ -427,10 +428,10 @@ void gated_rms_norm
else if (small && output_fp16)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const __nv_bfloat16*) x.data_ptr(),
(const __nv_bfloat16*) w.data_ptr(),
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(half*) y.data_ptr(),
(const __nv_bfloat16*) g.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
@@ -439,10 +440,10 @@ void gated_rms_norm
else if (small && output_fp32)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const __nv_bfloat16*) x.data_ptr(),
(const __nv_bfloat16*) w.data_ptr(),
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(float*) y.data_ptr(),
(const __nv_bfloat16*) g.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,