mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
RMSNorm/RoPE kernels: Allow BF16/FP32 norm weights
This commit is contained in:
@@ -130,12 +130,12 @@ __device__ __forceinline__ float _silu(float x)
|
||||
}
|
||||
|
||||
|
||||
template <typename input_t, typename output_t>
|
||||
template <typename input_t, typename output_t, typename weight_t>
|
||||
__global__ __launch_bounds__(NUM_THREADS)
|
||||
void rms_norm_kernel
|
||||
(
|
||||
const input_t* __restrict__ x,
|
||||
const half* __restrict__ w,
|
||||
const weight_t* __restrict__ w,
|
||||
output_t* __restrict__ y,
|
||||
const float epsilon,
|
||||
const int rows,
|
||||
@@ -149,6 +149,7 @@ void rms_norm_kernel
|
||||
constexpr bool output_fp16 = std::is_same_v<output_t, half>;
|
||||
static_assert(input_fp32 || input_fp16, "rms_norm_kernel: input must be float or half type");
|
||||
static_assert(output_fp32 || output_fp16, "rms_norm_kernel: output must be float or half type");
|
||||
constexpr bool weight_bf16 = std::is_same_v<weight_t, bfloat16>;
|
||||
|
||||
int t = threadIdx.x;
|
||||
int warp_id = threadIdx.x / 32;
|
||||
@@ -181,7 +182,11 @@ void rms_norm_kernel
|
||||
if (w)
|
||||
{
|
||||
float4 w4;
|
||||
read_half4<false>(w4, ((const half4*) w) + column);
|
||||
if constexpr (weight_bf16)
|
||||
read_bfloat164(w4, ((const bfloat164*) w) + column);
|
||||
else
|
||||
read_half4<false>(w4, ((const half4*) w) + column);
|
||||
|
||||
if (constant_bias != 0.0f)
|
||||
{
|
||||
w4.x += constant_bias;
|
||||
@@ -224,13 +229,19 @@ void rms_norm
|
||||
y = y.flatten(-2);
|
||||
}
|
||||
|
||||
TORCH_CHECK_DTYPE_OPT(w, kHalf);
|
||||
const half* w_ptr = (const half*) OPTPTR(w);
|
||||
TORCH_CHECK_DIV(x, -1, 4);
|
||||
if (w_ptr)
|
||||
TORCH_CHECK_SHAPES(x, -1, w.value(), 0, 1);
|
||||
TORCH_CHECK_SHAPES_FULL(x, y);
|
||||
|
||||
bool weight_bf16 = false;
|
||||
bool weight_fp16 = false;
|
||||
if (w_ptr)
|
||||
{
|
||||
TORCH_CHECK_SHAPES(x, -1, w.value(), 0, 1);
|
||||
weight_bf16 = w.value().dtype() == at::kBFloat16;
|
||||
weight_fp16 = w.value().dtype() == at::kHalf;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(x.device());
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
@@ -244,7 +255,7 @@ void rms_norm
|
||||
dim3 blockDim(NUM_THREADS, 1, 1);
|
||||
dim3 gridDim(rows, 1, 1);
|
||||
|
||||
if (input_fp16 && output_fp16)
|
||||
if (input_fp16 && output_fp16 && weight_fp16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const half*) x.data_ptr(),
|
||||
@@ -255,7 +266,7 @@ void rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (input_fp16 && output_fp32)
|
||||
else if (input_fp16 && output_fp32 && weight_fp16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const half*) x.data_ptr(),
|
||||
@@ -266,7 +277,7 @@ void rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (input_fp32 && output_fp16)
|
||||
else if (input_fp32 && output_fp16 && weight_fp16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const float*) x.data_ptr(),
|
||||
@@ -277,7 +288,7 @@ void rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (input_fp32 && output_fp32)
|
||||
else if (input_fp32 && output_fp32 && weight_fp16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const float*) x.data_ptr(),
|
||||
@@ -288,19 +299,63 @@ void rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (input_fp16 && output_fp16 && weight_bf16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const half*) x.data_ptr(),
|
||||
(const bfloat16*) w_ptr,
|
||||
(half*) y.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (input_fp16 && output_fp32 && weight_bf16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const half*) x.data_ptr(),
|
||||
(const bfloat16*) w_ptr,
|
||||
(float*) y.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (input_fp32 && output_fp16 && weight_bf16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const float*) x.data_ptr(),
|
||||
(const bfloat16*) w_ptr,
|
||||
(half*) y.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (input_fp32 && output_fp32 && weight_bf16)
|
||||
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const float*) x.data_ptr(),
|
||||
(const bfloat16*) w_ptr,
|
||||
(float*) y.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else
|
||||
TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output, must be half or float")
|
||||
TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output")
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
template <int num_threads, typename output_t>
|
||||
template <int num_threads, typename output_t, typename weight_t>
|
||||
__global__ __launch_bounds__(num_threads)
|
||||
void gated_rms_norm_kernel
|
||||
(
|
||||
const bfloat16* __restrict__ x,
|
||||
const bfloat16* __restrict__ w,
|
||||
const weight_t* __restrict__ w,
|
||||
output_t* __restrict__ y,
|
||||
const bfloat16* __restrict__ g,
|
||||
const float epsilon,
|
||||
@@ -311,6 +366,7 @@ void gated_rms_norm_kernel
|
||||
{
|
||||
constexpr bool output_fp32 = std::is_same_v<output_t, float>;
|
||||
constexpr bool output_fp16 = std::is_same_v<output_t, half>;
|
||||
constexpr bool weight_bf16 = std::is_same_v<weight_t, bfloat16>;
|
||||
static_assert(output_fp32 || output_fp16, "gated_rms_norm_kernel: output must be float or half type");
|
||||
|
||||
int t = threadIdx.x;
|
||||
@@ -340,7 +396,12 @@ void gated_rms_norm_kernel
|
||||
float4 w4;
|
||||
float4 g4;
|
||||
read_bfloat164(x4, ((const bfloat164*) (x + row * dim)) + column);
|
||||
read_bfloat164(w4, ((const bfloat164*) w) + column);
|
||||
|
||||
if constexpr (weight_bf16)
|
||||
read_bfloat164(w4, ((const bfloat164*) w) + column);
|
||||
else
|
||||
read_float4(w4, ((const float4*) w) + column);
|
||||
|
||||
read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column);
|
||||
|
||||
if (constant_bias != 0.0f)
|
||||
@@ -382,7 +443,7 @@ void gated_rms_norm
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
|
||||
TORCH_CHECK_DTYPE(x, kBFloat16);
|
||||
TORCH_CHECK_DTYPE(w, kBFloat16);
|
||||
// TORCH_CHECK_DTYPE(w, kBFloat16);
|
||||
TORCH_CHECK_DTYPE(g, kBFloat16);
|
||||
TORCH_CHECK_DIV(x, -1, 4);
|
||||
TORCH_CHECK_SHAPES(x, -1, w, 0, 1);
|
||||
@@ -391,6 +452,8 @@ void gated_rms_norm
|
||||
|
||||
bool output_fp32 = y.dtype() == at::kFloat;
|
||||
bool output_fp16 = y.dtype() == at::kHalf;
|
||||
bool weight_bf16 = w.dtype() == at::kBFloat16;
|
||||
bool weight_fp32 = w.dtype() == at::kFloat;
|
||||
|
||||
int rows = 1;
|
||||
for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i);
|
||||
@@ -401,7 +464,7 @@ void gated_rms_norm
|
||||
dim3 blockDim(small ? 32 : NUM_THREADS, 1, 1);
|
||||
dim3 gridDim(rows, 1, 1);
|
||||
|
||||
if (!small && output_fp16)
|
||||
if (!small && output_fp16 && weight_bf16)
|
||||
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
@@ -413,7 +476,7 @@ void gated_rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (!small && output_fp32)
|
||||
else if (!small && output_fp32 && weight_bf16)
|
||||
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
@@ -425,7 +488,7 @@ void gated_rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (small && output_fp16)
|
||||
else if (small && output_fp16 && weight_bf16)
|
||||
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
@@ -437,7 +500,7 @@ void gated_rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (small && output_fp32)
|
||||
else if (small && output_fp32 && weight_bf16)
|
||||
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
@@ -449,6 +512,54 @@ void gated_rms_norm
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (!small && output_fp16 && weight_fp32)
|
||||
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const float*) w.data_ptr(),
|
||||
(half*) y.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (!small && output_fp32 && weight_fp32)
|
||||
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const float*) w.data_ptr(),
|
||||
(float*) y.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (small && output_fp16 && weight_fp32)
|
||||
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const float*) w.data_ptr(),
|
||||
(half*) y.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else if (small && output_fp32 && weight_fp32)
|
||||
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
|
||||
(
|
||||
(const bfloat16*) x.data_ptr(),
|
||||
(const float*) w.data_ptr(),
|
||||
(float*) y.data_ptr(),
|
||||
(const bfloat16*) g.data_ptr(),
|
||||
epsilon,
|
||||
rows,
|
||||
dim,
|
||||
constant_bias
|
||||
);
|
||||
else
|
||||
TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output, must be half or float")
|
||||
|
||||
|
||||
@@ -13,7 +13,10 @@
|
||||
#define ROPESTYLE_NANOCHAT 3
|
||||
#define MAX_NUM_THREADS 1024
|
||||
|
||||
template <int rope_mode>
|
||||
using bfloat16 = __nv_bfloat16;
|
||||
using bfloat162 = __nv_bfloat162;
|
||||
|
||||
template <int rope_mode, bool norm_bf16>
|
||||
__global__
|
||||
void rope_kernel
|
||||
(
|
||||
@@ -32,8 +35,8 @@ void rope_kernel
|
||||
const uint32_t* __restrict__ positions,
|
||||
const uint32_t* __restrict__ position_ids,
|
||||
float attn_factor,
|
||||
const half* __restrict__ q_norm,
|
||||
const half* __restrict__ k_norm,
|
||||
const void* __restrict__ q_norm,
|
||||
const void* __restrict__ k_norm,
|
||||
const float norm_eps,
|
||||
const float norm_constant_bias,
|
||||
const bool inv_freq_table,
|
||||
@@ -84,7 +87,6 @@ void rope_kernel
|
||||
__shared__ float sums[MAX_NUM_THREADS / 32];
|
||||
|
||||
// Prep
|
||||
half2 norm_constant_bias_h2 = __float2half2_rn(norm_constant_bias);
|
||||
int head_dim_pad = (head_dim + 63) / 64 * 64;
|
||||
|
||||
// Loop over heads
|
||||
@@ -92,7 +94,7 @@ void rope_kernel
|
||||
{
|
||||
const half* g_head_in_ptr;
|
||||
half* g_head_out_ptr;
|
||||
const half* norm_weight;
|
||||
const void* norm_weight;
|
||||
if (head_idx < num_heads_q)
|
||||
{
|
||||
g_head_in_ptr = q + ((batch * seq_len + token_pos) * num_heads_q + head_idx) * head_dim;
|
||||
@@ -165,7 +167,6 @@ void rope_kernel
|
||||
auto apply_norm = [&] ()
|
||||
{
|
||||
half2 *tptr = (half2*)(sh_head + t * 2);
|
||||
half2 *wptr = (half2*)(norm_weight + t * 2);
|
||||
// int lane_id = threadIdx.x % 32;
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int warps = blockDim.x / 32;
|
||||
@@ -185,11 +186,27 @@ void rope_kernel
|
||||
float rmf = rsqrtf(sum / (float) head_dim + norm_eps);
|
||||
v1 *= rmf;
|
||||
v2 *= rmf;
|
||||
v = __floats2half2_rn(v1, v2);
|
||||
|
||||
// Apply weight and store
|
||||
half2 w = __hadd2(*wptr, norm_constant_bias_h2);
|
||||
v = __hmul2(w, v);
|
||||
// Downcast, apply weight and store
|
||||
if constexpr (norm_bf16)
|
||||
{
|
||||
bfloat162 *wptr = (bfloat162*)(((bfloat16*)norm_weight) + t * 2);
|
||||
float2 w = __bfloat1622float2(*wptr);
|
||||
w.x += norm_constant_bias;
|
||||
w.y += norm_constant_bias;
|
||||
v1 *= w.x;
|
||||
v2 *= w.y;
|
||||
v = __floats2half2_rn(v1, v2);
|
||||
}
|
||||
else
|
||||
{
|
||||
half2 norm_constant_bias_h2 = __float2half2_rn(norm_constant_bias);
|
||||
half2 *wptr = (half2*)(((half*)norm_weight) + t * 2);
|
||||
half2 w = __hadd2(*wptr, norm_constant_bias_h2);
|
||||
v = __floats2half2_rn(v1, v2);
|
||||
v = __hmul2(w, v);
|
||||
}
|
||||
|
||||
*tptr = v;
|
||||
__syncthreads();
|
||||
};
|
||||
@@ -334,14 +351,17 @@ void rope
|
||||
TORCH_CHECK(position_ids.value().size(1) == seq_len, "position_ids is incorrect shape");
|
||||
}
|
||||
|
||||
half* q_norm_ptr = (half*) OPTPTR(q_norm);
|
||||
half* k_norm_ptr = (half*) OPTPTR(k_norm);
|
||||
TORCH_CHECK_DTYPE_OPT(q_norm, kHalf);
|
||||
TORCH_CHECK_DTYPE_OPT(k_norm, kHalf);
|
||||
void* q_norm_ptr = (void*) OPTPTR(q_norm);
|
||||
void* k_norm_ptr = (void*) OPTPTR(k_norm);
|
||||
bool norm_fp16 = true;
|
||||
bool norm_bf16 = false;
|
||||
if (q_norm_ptr)
|
||||
{
|
||||
TORCH_CHECK_DIM(q_norm.value(), 1);
|
||||
TORCH_CHECK(q_norm.value().size(0) == head_dim, "q_norm is incorrect size");
|
||||
norm_bf16 = q_norm.value().dtype() == at::kBFloat16;
|
||||
norm_fp16 = q_norm.value().dtype() == at::kHalf;
|
||||
TORCH_CHECK(k_norm.value().dtype() == q_norm.value().dtype(), "q_norm and k_norm must be same dtype");
|
||||
}
|
||||
|
||||
dim3 blocks(seq_len, bsz);
|
||||
@@ -355,9 +375,19 @@ void rope
|
||||
q_norm_ptr, k_norm_ptr, norm_eps, norm_constant_bias, inv_freq_table, inv_freq_stride, \
|
||||
llama_4_scaling_beta, llama_4_scaling_original, post_rope_norm
|
||||
|
||||
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
if (norm_fp16)
|
||||
{
|
||||
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, false><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, false><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, false><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
}
|
||||
else if (norm_bf16)
|
||||
{
|
||||
if (rope_mode == ROPESTYLE_GPTJ) rope_kernel<ROPESTYLE_GPTJ, true><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
else if (rope_mode == ROPESTYLE_NEOX) rope_kernel<ROPESTYLE_NEOX, true><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
else if (rope_mode == ROPESTYLE_NANOCHAT) rope_kernel<ROPESTYLE_NANOCHAT, true><<<blocks, threads, 0, stream>>>(ARGS);
|
||||
}
|
||||
else TORCH_CHECK(false, "rope: incorrect norm dtype");
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
@@ -40,7 +40,7 @@ class RMSNorm(Module):
|
||||
def load(self, device: torch.device, **kwargs):
|
||||
self.device = device
|
||||
if not self.unweighted:
|
||||
weight = self.config.stc.get_tensor(f"{self.key}.weight", self.device, float2half = True)
|
||||
weight = self.config.stc.get_tensor(f"{self.key}.weight", self.device, float2half = True, allow_bf16 = True)
|
||||
self._numel = weight.numel()
|
||||
self.weight = nn.Parameter(weight, requires_grad = False)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user