RMSNorm/RoPE kernels: Allow BF16/FP32 norm weights

This commit is contained in:
turboderp
2026-03-02 00:54:15 +01:00
parent e2f4198406
commit 93695e9a7d
3 changed files with 178 additions and 37 deletions

View File

@@ -130,12 +130,12 @@ __device__ __forceinline__ float _silu(float x)
}
template <typename input_t, typename output_t>
template <typename input_t, typename output_t, typename weight_t>
__global__ __launch_bounds__(NUM_THREADS)
void rms_norm_kernel
(
const input_t* __restrict__ x,
const half* __restrict__ w,
const weight_t* __restrict__ w,
output_t* __restrict__ y,
const float epsilon,
const int rows,
@@ -149,6 +149,7 @@ void rms_norm_kernel
constexpr bool output_fp16 = std::is_same_v<output_t, half>;
static_assert(input_fp32 || input_fp16, "rms_norm_kernel: input must be float or half type");
static_assert(output_fp32 || output_fp16, "rms_norm_kernel: output must be float or half type");
constexpr bool weight_bf16 = std::is_same_v<weight_t, bfloat16>;
int t = threadIdx.x;
int warp_id = threadIdx.x / 32;
@@ -181,7 +182,11 @@ void rms_norm_kernel
if (w)
{
float4 w4;
read_half4<false>(w4, ((const half4*) w) + column);
if constexpr (weight_bf16)
read_bfloat164(w4, ((const bfloat164*) w) + column);
else
read_half4<false>(w4, ((const half4*) w) + column);
if (constant_bias != 0.0f)
{
w4.x += constant_bias;
@@ -224,13 +229,19 @@ void rms_norm
y = y.flatten(-2);
}
TORCH_CHECK_DTYPE_OPT(w, kHalf);
const half* w_ptr = (const half*) OPTPTR(w);
TORCH_CHECK_DIV(x, -1, 4);
if (w_ptr)
TORCH_CHECK_SHAPES(x, -1, w.value(), 0, 1);
TORCH_CHECK_SHAPES_FULL(x, y);
bool weight_bf16 = false;
bool weight_fp16 = false;
if (w_ptr)
{
TORCH_CHECK_SHAPES(x, -1, w.value(), 0, 1);
weight_bf16 = w.value().dtype() == at::kBFloat16;
weight_fp16 = w.value().dtype() == at::kHalf;
}
const at::cuda::OptionalCUDAGuard device_guard(x.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
@@ -244,7 +255,7 @@ void rms_norm
dim3 blockDim(NUM_THREADS, 1, 1);
dim3 gridDim(rows, 1, 1);
if (input_fp16 && output_fp16)
if (input_fp16 && output_fp16 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
@@ -255,7 +266,7 @@ void rms_norm
dim,
constant_bias
);
else if (input_fp16 && output_fp32)
else if (input_fp16 && output_fp32 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
@@ -266,7 +277,7 @@ void rms_norm
dim,
constant_bias
);
else if (input_fp32 && output_fp16)
else if (input_fp32 && output_fp16 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
@@ -277,7 +288,7 @@ void rms_norm
dim,
constant_bias
);
else if (input_fp32 && output_fp32)
else if (input_fp32 && output_fp32 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
@@ -288,19 +299,63 @@ void rms_norm
dim,
constant_bias
);
else if (input_fp16 && output_fp16 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
(const bfloat16*) w_ptr,
(half*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp16 && output_fp32 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
(const bfloat16*) w_ptr,
(float*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp32 && output_fp16 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
(const bfloat16*) w_ptr,
(half*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp32 && output_fp32 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
(const bfloat16*) w_ptr,
(float*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else
TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output, must be half or float")
TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output")
cuda_check(cudaPeekAtLastError());
}
template <int num_threads, typename output_t>
template <int num_threads, typename output_t, typename weight_t>
__global__ __launch_bounds__(num_threads)
void gated_rms_norm_kernel
(
const bfloat16* __restrict__ x,
const bfloat16* __restrict__ w,
const weight_t* __restrict__ w,
output_t* __restrict__ y,
const bfloat16* __restrict__ g,
const float epsilon,
@@ -311,6 +366,7 @@ void gated_rms_norm_kernel
{
constexpr bool output_fp32 = std::is_same_v<output_t, float>;
constexpr bool output_fp16 = std::is_same_v<output_t, half>;
constexpr bool weight_bf16 = std::is_same_v<weight_t, bfloat16>;
static_assert(output_fp32 || output_fp16, "gated_rms_norm_kernel: output must be float or half type");
int t = threadIdx.x;
@@ -340,7 +396,12 @@ void gated_rms_norm_kernel
float4 w4;
float4 g4;
read_bfloat164(x4, ((const bfloat164*) (x + row * dim)) + column);
read_bfloat164(w4, ((const bfloat164*) w) + column);
if constexpr (weight_bf16)
read_bfloat164(w4, ((const bfloat164*) w) + column);
else
read_float4(w4, ((const float4*) w) + column);
read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column);
if (constant_bias != 0.0f)
@@ -382,7 +443,7 @@ void gated_rms_norm
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_DTYPE(x, kBFloat16);
TORCH_CHECK_DTYPE(w, kBFloat16);
// TORCH_CHECK_DTYPE(w, kBFloat16);
TORCH_CHECK_DTYPE(g, kBFloat16);
TORCH_CHECK_DIV(x, -1, 4);
TORCH_CHECK_SHAPES(x, -1, w, 0, 1);
@@ -391,6 +452,8 @@ void gated_rms_norm
bool output_fp32 = y.dtype() == at::kFloat;
bool output_fp16 = y.dtype() == at::kHalf;
bool weight_bf16 = w.dtype() == at::kBFloat16;
bool weight_fp32 = w.dtype() == at::kFloat;
int rows = 1;
for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i);
@@ -401,7 +464,7 @@ void gated_rms_norm
dim3 blockDim(small ? 32 : NUM_THREADS, 1, 1);
dim3 gridDim(rows, 1, 1);
if (!small && output_fp16)
if (!small && output_fp16 && weight_bf16)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
@@ -413,7 +476,7 @@ void gated_rms_norm
dim,
constant_bias
);
else if (!small && output_fp32)
else if (!small && output_fp32 && weight_bf16)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
@@ -425,7 +488,7 @@ void gated_rms_norm
dim,
constant_bias
);
else if (small && output_fp16)
else if (small && output_fp16 && weight_bf16)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
@@ -437,7 +500,7 @@ void gated_rms_norm
dim,
constant_bias
);
else if (small && output_fp32)
else if (small && output_fp32 && weight_bf16)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
@@ -449,6 +512,54 @@ void gated_rms_norm
dim,
constant_bias
);
else if (!small && output_fp16 && weight_fp32)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(half*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (!small && output_fp32 && weight_fp32)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(float*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (small && output_fp16 && weight_fp32)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(half*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (small && output_fp32 && weight_fp32)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(float*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else
TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output, must be half or float")

View File

@@ -13,7 +13,10 @@
#define ROPESTYLE_NANOCHAT 3
#define MAX_NUM_THREADS 1024
template <int rope_mode>
using bfloat16 = __nv_bfloat16;
using bfloat162 = __nv_bfloat162;
template <int rope_mode, bool norm_bf16>
__global__
void rope_kernel
(
@@ -32,8 +35,8 @@ void rope_kernel
const uint32_t* __restrict__ positions,
const uint32_t* __restrict__ position_ids,
float attn_factor,
const half* __restrict__ q_norm,
const half* __restrict__ k_norm,
const void* __restrict__ q_norm,
const void* __restrict__ k_norm,
const float norm_eps,
const float norm_constant_bias,
const bool inv_freq_table,
@@ -84,7 +87,6 @@ void rope_kernel
__shared__ float sums[MAX_NUM_THREADS / 32];
// Prep
half2 norm_constant_bias_h2 = __float2half2_rn(norm_constant_bias);
int head_dim_pad = (head_dim + 63) / 64 * 64;
// Loop over heads
@@ -92,7 +94,7 @@ void rope_kernel
{
const half* g_head_in_ptr;
half* g_head_out_ptr;
const half* norm_weight;
const void* norm_weight;
if (head_idx < num_heads_q)
{
g_head_in_ptr = q + ((batch * seq_len + token_pos) * num_heads_q + head_idx) * head_dim;
@@ -165,7 +167,6 @@ void rope_kernel
auto apply_norm = [&] ()
{
half2 *tptr = (half2*)(sh_head + t * 2);
half2 *wptr = (half2*)(norm_weight + t * 2);
// int lane_id = threadIdx.x % 32;
int warp_id = threadIdx.x / 32;
int warps = blockDim.x / 32;
@@ -185,11 +186,27 @@ void rope_kernel
float rmf = rsqrtf(sum / (float) head_dim + norm_eps);
v1 *= rmf;
v2 *= rmf;
v = __floats2half2_rn(v1, v2);
// Apply weight and store
half2 w = __hadd2(*wptr, norm_constant_bias_h2);
v = __hmul2(w, v);
// Downcast, apply weight and store
if constexpr (norm_bf16)
{
bfloat162 *wptr = (bfloat162*)(((bfloat16*)norm_weight) + t * 2);
float2 w = __bfloat1622float2(*wptr);
w.x += norm_constant_bias;
w.y += norm_constant_bias;
v1 *= w.x;
v2 *= w.y;
v = __floats2half2_rn(v1, v2);
}
else
{
half2 norm_constant_bias_h2 = __float2half2_rn(norm_constant_bias);
half2 *wptr = (half2*)(((half*)norm_weight) + t * 2);
half2 w = __hadd2(*wptr, norm_constant_bias_h2);
v = __floats2half2_rn(v1, v2);
v = __hmul2(w, v);
}
*tptr = v;
__syncthreads();
};
@@ -334,14 +351,17 @@ void rope
TORCH_CHECK(position_ids.value().size(1) == seq_len, "position_ids is incorrect shape");
}
half* q_norm_ptr = (half*) OPTPTR(q_norm);
half* k_norm_ptr = (half*) OPTPTR(k_norm);
TORCH_CHECK_DTYPE_OPT(q_norm, kHalf);
TORCH_CHECK_DTYPE_OPT(k_norm, kHalf);
void* q_norm_ptr = (void*) OPTPTR(q_norm);
void* k_norm_ptr = (void*) OPTPTR(k_norm);
bool norm_fp16 = true;
bool norm_bf16 = false;
if (q_norm_ptr)
{
TORCH_CHECK_DIM(q_norm.value(), 1);
TORCH_CHECK(q_norm.value().size(0) == head_dim, "q_norm is incorrect size");
norm_bf16 = q_norm.value().dtype() == at::kBFloat16;
norm_fp16 = q_norm.value().dtype() == at::kHalf;
TORCH_CHECK(k_norm.value().dtype() == q_norm.value().dtype(), "q_norm and k_norm must be same dtype");
}
dim3 blocks(seq_len, bsz);
@@ -355,9 +375,19 @@ void rope
q_norm_ptr, k_norm_ptr, norm_eps, norm_constant_bias, inv_freq_table, inv_freq_stride, \
llama_4_scaling_beta, llama_4_scaling_original, post_rope_norm
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT><<<blocks, threads, 0, stream>>>(ARGS);
if (norm_fp16)
{
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, false><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, false><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, false><<<blocks, threads, 0, stream>>>(ARGS);
}
else if (norm_bf16)
{
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, true><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, true><<<blocks, threads, 0, stream>>>(ARGS);
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, true><<<blocks, threads, 0, stream>>>(ARGS);
}
else TORCH_CHECK(false, "rope: incorrect norm dtype");
cuda_check(cudaPeekAtLastError());
}

View File

@@ -40,7 +40,7 @@ class RMSNorm(Module):
def load(self, device: torch.device, **kwargs):
self.device = device
if not self.unweighted:
weight = self.config.stc.get_tensor(f"{self.key}.weight", self.device, float2half = True)
weight = self.config.stc.get_tensor(f"{self.key}.weight", self.device, float2half = True, allow_bf16 = True)
self._numel = weight.numel()
self.weight = nn.Parameter(weight, requires_grad = False)
else: