mirror of https://github.com/turboderp-org/exllamav3.git
RMSNorm/GatedRMSNorm: Tidy up launch logic with macros and add more dtypes
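The change in a nutshell: both launchers previously dispatched on dtype through a hand-written if/else ladder, with one fully spelled-out kernel launch per dtype combination, so every added dtype multiplied the boilerplate. The commit replaces each ladder with a `__(...)` macro that generates one guarded launch per table row. A minimal standalone sketch of the pattern follows; the names (`my_kernel`, `dispatch_example`, the `dtype` enum) are hypothetical stand-ins, not code from the repo:

    #include <cuda_fp16.h>
    #include <cstdio>

    // Stand-in for rms_norm_kernel: templated on input/output element types.
    template <typename in_t, typename out_t>
    __global__ void my_kernel(const in_t* x, out_t* y, int n)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) y[i] = (out_t) (float) x[i];
    }

    enum dtype { FP16, FP32 };

    // Dispatch macro: _ti/_to are runtime tags, __ti/__to the C++ types.
    #define __(_ti, __ti, _to, __to) \
        if (ti == _ti && to == _to) \
            my_kernel<__ti, __to><<<blocks, threads>>>((const __ti*) x, (__to*) y, n);

    void dispatch_example(dtype ti, dtype to, const void* x, void* y, int n)
    {
        int threads = 256, blocks = (n + threads - 1) / threads;
        //      in_type_____ out_type____
        __     (FP16, half,  FP16, half )
        else __(FP16, half,  FP32, float)
        else __(FP32, float, FP16, half )
        else __(FP32, float, FP32, float)
        else printf("unsupported dtype combination\n");
    }
    #undef __

Each table row reads as one supported (input, output) pair, and adding a dtype means adding rows rather than duplicating a whole launch block.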
@@ -176,16 +176,16 @@ void rms_norm_kernel
     for (int column = t; column < columns; column += NUM_THREADS)
     {
         float4 x4;
-        if constexpr (input_fp16) read_half4<true>(x4, ((const half4*) (x + row * dim)) + column);
-        if constexpr (input_fp32) read_float4(x4, ((const float4*) (x + row * dim)) + column);
+        if constexpr (input_fp16) read_half4<true>(x4, ((const half4*) (x + row * dim)) + column);
+        if constexpr (input_fp32) read_float4 (x4, ((const float4*) (x + row * dim)) + column);

         if (w)
         {
             float4 w4;
-            if constexpr (weight_bf16)
-                read_bfloat164(w4, ((const bfloat164*) w) + column);
-            else
-                read_half4<false>(w4, ((const half4*) w) + column);
+            if constexpr (weight_bf16) read_bfloat164 (w4, ((const bfloat164*) w) + column);
+            else                       read_half4<false>(w4, ((const half4*) w) + column);

             if (constant_bias != 0.0f)
             {
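The `half4`/`bfloat164` packed types and the `read_half4`/`read_float4`/`read_bfloat164` helpers are defined elsewhere in the repo, outside this diff. For orientation only, a hypothetical sketch of what the fp16 variant plausibly looks like; the meaning of the real bool template parameter is not visible here, so it is left as an unused placeholder:

    #include <cuda_fp16.h>

    // Hypothetical reconstruction: four fp16 values packed into 8 bytes so a
    // single aligned load fetches four elements at once. The real definitions
    // in exllamav3 may differ.
    struct __align__(8) half4 { half2 a, b; };

    template <bool flag>  // placeholder; the real parameter may select a load variant
    __device__ inline void read_half4(float4& dst, const half4* src)
    {
        half4 v = *src;                           // one 8-byte vectorized load
        float2 lo = __half22float2(v.a);          // unpack elements 0..1 to fp32
        float2 hi = __half22float2(v.b);          // unpack elements 2..3 to fp32
        dst = make_float4(lo.x, lo.y, hi.x, hi.y);
    }

Whatever the exact definitions, a pattern like this is why the kernel casts to `(const half4*)` and strides by `column`: each thread consumes the row four elements at a time.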
@@ -223,141 +223,75 @@ void rms_norm
     bool span_heads
 )
 {
-    const at::cuda::OptionalCUDAGuard device_guard(x.device());
-    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
-
     if (span_heads)
     {
         x = x.flatten(-2);
         y = y.flatten(-2);
     }

-    const half* w_ptr = (const half*) OPTPTR(w);
     TORCH_CHECK_DIV(x, -1, 4);
     TORCH_CHECK_SHAPES_FULL(x, y);

-    bool weight_bf16 = false;
-    bool weight_fp16 = false;
+    auto tx = x.dtype();
+    auto tw = x.dtype();  // intentional: placeholder type, ignored if w is None
+    auto ty = y.dtype();
+
+    const half* w_ptr = (const half*) OPTPTR(w);
     if (w_ptr)
     {
         TORCH_CHECK_SHAPES(x, -1, w.value(), 0, 1);
-        weight_bf16 = w.value().dtype() == at::kBFloat16;
-        weight_fp16 = w.value().dtype() == at::kHalf;
+        tw = w.value().dtype();
     }

+    const at::cuda::OptionalCUDAGuard device_guard(x.device());
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
+
-    bool input_fp32 = x.dtype() == at::kFloat;
-    bool output_fp32 = y.dtype() == at::kFloat;
-    bool input_fp16 = !input_fp32;
-    bool output_fp16 = !output_fp32;
     int rows = 1;
     for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i);
     int dim = x.size(-1);

     dim3 blockDim(NUM_THREADS, 1, 1);
     dim3 gridDim(rows, 1, 1);

-    if (input_fp16 && output_fp16 && weight_fp16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const half*) x.data_ptr(),
-            (const half*) w_ptr,
-            (half*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
+    // Launch macro
+    #define __(_tx, __tx, _tw, __tw, _ty, __ty) \
+    if (tx == at::_tx && tw == at::_tw && ty == at::_ty) \
+        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>> \
+        ( \
+            (const __tx*) x.data_ptr(), \
+            (const __tw*) w_ptr, \
+            (__ty*) y.data_ptr(), \
+            epsilon, \
+            rows, \
+            dim, \
+            constant_bias \
         );
-    else if (input_fp16 && output_fp32 && weight_fp16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const half*) x.data_ptr(),
-            (const half*) w_ptr,
-            (float*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (input_fp32 && output_fp16 && weight_fp16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const float*) x.data_ptr(),
-            (const half*) w_ptr,
-            (half*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (input_fp32 && output_fp32 && weight_fp16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const float*) x.data_ptr(),
-            (const half*) w_ptr,
-            (float*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (input_fp16 && output_fp16 && weight_bf16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const half*) x.data_ptr(),
-            (const bfloat16*) w_ptr,
-            (half*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (input_fp16 && output_fp32 && weight_bf16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const half*) x.data_ptr(),
-            (const bfloat16*) w_ptr,
-            (float*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (input_fp32 && output_fp16 && weight_bf16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const float*) x.data_ptr(),
-            (const bfloat16*) w_ptr,
-            (half*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (input_fp32 && output_fp32 && weight_bf16)
-        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
-        (
-            (const float*) x.data_ptr(),
-            (const bfloat16*) w_ptr,
-            (float*) y.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else
-        TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output");
+
+    //      x_type_______  w_type_____________  y_type_______
+    __     (kHalf,  half,  kHalf,     half,     kHalf,  half )
+    else __(kHalf,  half,  kHalf,     half,     kFloat, float)
+    else __(kFloat, float, kHalf,     half,     kHalf,  half )
+    else __(kFloat, float, kHalf,     half,     kFloat, float)
+    else __(kHalf,  half,  kBFloat16, bfloat16, kHalf,  half )
+    else __(kHalf,  half,  kBFloat16, bfloat16, kFloat, float)
+    else __(kFloat, float, kBFloat16, bfloat16, kHalf,  half )
+    else __(kFloat, float, kBFloat16, bfloat16, kFloat, float)
+
+    else TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output");
+    #undef __

     cuda_check(cudaPeekAtLastError());
 }
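For reference, each row of the table goes through the `__` macro above and expands (modulo whitespace) into one guarded launch, identical in shape to the blocks it replaces. For example, `else __(kHalf, half, kBFloat16, bfloat16, kFloat, float)` expands to:

    else if (tx == at::kHalf && tw == at::kBFloat16 && ty == at::kFloat)
        rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
        (
            (const half*) x.data_ptr(),
            (const bfloat16*) w_ptr,
            (float*) y.data_ptr(),
            epsilon,
            rows,
            dim,
            constant_bias
        );

The kernel's template parameters are deduced from the cast argument types, exactly as in the removed if/else ladder, so the eight rows reproduce the eight removed branches in one line each.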
-template <int num_threads, typename output_t, typename weight_t>
+template <int num_threads, typename output_t, typename weight_t, typename gate_t>
 __global__ __launch_bounds__(num_threads)
 void gated_rms_norm_kernel
 (
     const bfloat16* __restrict__ x,
     const weight_t* __restrict__ w,
     output_t* __restrict__ y,
-    const bfloat16* __restrict__ g,
+    const gate_t* __restrict__ g,
     const float epsilon,
     const int rows,
     const int dim,
@@ -367,7 +301,7 @@ void gated_rms_norm_kernel
     constexpr bool output_fp32 = std::is_same_v<output_t, float>;
     constexpr bool output_fp16 = std::is_same_v<output_t, half>;
     constexpr bool weight_bf16 = std::is_same_v<weight_t, bfloat16>;
     static_assert(output_fp32 || output_fp16, "gated_rms_norm_kernel: output must be float or half type");
+    constexpr bool gate_fp32 = std::is_same_v<gate_t, float>;

     int t = threadIdx.x;
     int warp_id = threadIdx.x / 32;
@@ -395,14 +329,12 @@ void gated_rms_norm_kernel
     float4 x4;
     float4 w4;
     float4 g4;

     read_bfloat164(x4, ((const bfloat164*) (x + row * dim)) + column);

-    if constexpr (weight_bf16)
-        read_bfloat164(w4, ((const bfloat164*) w) + column);
-    else
-        read_float4(w4, ((const float4*) w) + column);
-
-    read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column);
+    if constexpr (weight_bf16) read_bfloat164(w4, ((const bfloat164*) w) + column);
+    else                       read_float4   (w4, ((const float4*) w) + column);
+    if constexpr (gate_fp32)   read_float4   (g4, ((const float4*) (g + row * dim)) + column);
+    else                       read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column);

     if (constant_bias != 0.0f)
     {
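Since `weight_bf16` and `gate_fp32` are `constexpr`, each instantiation of the kernel keeps exactly one load path per operand and pays no runtime branch. A toy standalone illustration of the same `if constexpr` dispatch (the `load_gate` helper is hypothetical, not repo code):

    #include <cuda_bf16.h>
    #include <type_traits>

    // gate_t = float compiles to a direct load; gate_t = __nv_bfloat16 keeps
    // only the conversion path. The untaken branch is discarded at compile time.
    template <typename gate_t>
    __device__ inline float load_gate(const gate_t* g, int i)
    {
        if constexpr (std::is_same_v<gate_t, float>)
            return g[i];                          // fp32 gate: load as-is
        else
            return __bfloat162float(g[i]);        // bf16 gate: convert to fp32
    }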
@@ -442,19 +374,10 @@ void gated_rms_norm
     const at::cuda::OptionalCUDAGuard device_guard(x.device());
     cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();

     TORCH_CHECK_DTYPE(x, kBFloat16);
-    // TORCH_CHECK_DTYPE(w, kBFloat16);
-    TORCH_CHECK_DTYPE(g, kBFloat16);
     TORCH_CHECK_DIV(x, -1, 4);
     TORCH_CHECK_SHAPES(x, -1, w, 0, 1);
-    // TORCH_CHECK_SHAPES_FULL(x, y);
     TORCH_CHECK_SHAPES_FULL(x, g);

-    bool output_fp32 = y.dtype() == at::kFloat;
-    bool output_fp16 = y.dtype() == at::kHalf;
-    bool weight_bf16 = w.dtype() == at::kBFloat16;
-    bool weight_fp32 = w.dtype() == at::kFloat;
-
     int rows = 1;
     for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i);
     int dim = x.size(-1);
@@ -464,104 +387,46 @@ void gated_rms_norm
     dim3 blockDim(small ? 32 : NUM_THREADS, 1, 1);
     dim3 gridDim(rows, 1, 1);

-    if (!small && output_fp16 && weight_bf16)
-        gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const bfloat16*) w.data_ptr(),
-            (half*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
+    auto tx = x.dtype();
+    auto tw = w.dtype();
+    auto ty = y.dtype();
+    auto tg = g.dtype();
+
+    // Launch macro
+    #define __(_tx, __tx, _tw, __tw, _ty, __ty, _tg, __tg, _small, __num_threads) \
+    if (small == _small && tx == at::_tx && tw == at::_tw && ty == at::_ty && tg == at::_tg) \
+        gated_rms_norm_kernel<__num_threads><<<gridDim, blockDim, 0, stream>>> \
+        ( \
+            (const __tx*) x.data_ptr(), \
+            (const __tw*) w.data_ptr(), \
+            (__ty*) y.data_ptr(), \
+            (const __tg*) g.data_ptr(), \
+            epsilon, \
+            rows, \
+            dim, \
+            constant_bias \
         );
-    else if (!small && output_fp32 && weight_bf16)
-        gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const bfloat16*) w.data_ptr(),
-            (float*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (small && output_fp16 && weight_bf16)
-        gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const bfloat16*) w.data_ptr(),
-            (half*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (small && output_fp32 && weight_bf16)
-        gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const bfloat16*) w.data_ptr(),
-            (float*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (!small && output_fp16 && weight_fp32)
-        gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const float*) w.data_ptr(),
-            (half*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (!small && output_fp32 && weight_fp32)
-        gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const float*) w.data_ptr(),
-            (float*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (small && output_fp16 && weight_fp32)
-        gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const float*) w.data_ptr(),
-            (half*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else if (small && output_fp32 && weight_fp32)
-        gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
-        (
-            (const bfloat16*) x.data_ptr(),
-            (const float*) w.data_ptr(),
-            (float*) y.data_ptr(),
-            (const bfloat16*) g.data_ptr(),
-            epsilon,
-            rows,
-            dim,
-            constant_bias
-        );
-    else
-        TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output, must be half or float");
+
+    //      x_type_____________   w_type_____________   y_type_______   g_type_____________   small   num_threads
+    __     (kBFloat16, bfloat16,  kFloat,    float,      kHalf,  half,   kBFloat16, bfloat16,  true,   32         )
+    else __(kBFloat16, bfloat16,  kFloat,    float,      kHalf,  half,   kBFloat16, bfloat16,  false,  NUM_THREADS)
+    else __(kBFloat16, bfloat16,  kFloat,    float,      kFloat, float,  kBFloat16, bfloat16,  true,   32         )
+    else __(kBFloat16, bfloat16,  kFloat,    float,      kFloat, float,  kBFloat16, bfloat16,  false,  NUM_THREADS)
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kHalf,  half,   kBFloat16, bfloat16,  true,   32         )
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kHalf,  half,   kBFloat16, bfloat16,  false,  NUM_THREADS)
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kFloat, float,  kBFloat16, bfloat16,  true,   32         )
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kFloat, float,  kBFloat16, bfloat16,  false,  NUM_THREADS)
+    else __(kBFloat16, bfloat16,  kFloat,    float,      kHalf,  half,   kFloat,    float,     true,   32         )
+    else __(kBFloat16, bfloat16,  kFloat,    float,      kHalf,  half,   kFloat,    float,     false,  NUM_THREADS)
+    else __(kBFloat16, bfloat16,  kFloat,    float,      kFloat, float,  kFloat,    float,     true,   32         )
+    else __(kBFloat16, bfloat16,  kFloat,    float,      kFloat, float,  kFloat,    float,     false,  NUM_THREADS)
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kHalf,  half,   kFloat,    float,     true,   32         )
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kHalf,  half,   kFloat,    float,     false,  NUM_THREADS)
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kFloat, float,  kFloat,    float,     true,   32         )
+    else __(kBFloat16, bfloat16,  kBFloat16, bfloat16,   kFloat, float,  kFloat,    float,     false,  NUM_THREADS)
+
+    else TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output");
+    #undef __

     cuda_check(cudaPeekAtLastError());
 }
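As with `rms_norm`, each row expands through the macro into one guarded launch; the extra columns pin the `small` flag (computed in code elided from this hunk) and the matching thread count, so every dtype combination appears twice, once for 32-thread blocks and once for NUM_THREADS. For example, `else __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kFloat, float, true, 32)` expands to (wrapped here for readability):

    else if (small == true && tx == at::kBFloat16 && tw == at::kFloat
             && ty == at::kHalf && tg == at::kFloat)
        gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
        (
            (const bfloat16*) x.data_ptr(),
            (const float*) w.data_ptr(),
            (half*) y.data_ptr(),
            (const float*) g.data_ptr(),
            epsilon,
            rows,
            dim,
            constant_bias
        );

The rows with tg == kFloat are the new dtype support the commit title refers to: the removed ladder only handled bf16 gates, so the table doubles the covered combinations from eight to sixteen.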