RMSNorm/GatedRMSNorm: Tidy up launch logic with macros and add more dtypes

This commit is contained in:
turboderp
2026-03-02 22:22:47 +01:00
parent 67785fc286
commit 725a75386d

View File

@@ -176,16 +176,16 @@ void rms_norm_kernel
for (int column = t; column < columns; column += NUM_THREADS)
{
float4 x4;
if constexpr (input_fp16) read_half4<true>(x4, ((const half4*) (x + row * dim)) + column);
if constexpr (input_fp32) read_float4(x4, ((const float4*) (x + row * dim)) + column);
if constexpr (input_fp16) read_half4<true>(x4, ((const half4*) (x + row * dim)) + column);
if constexpr (input_fp32) read_float4 (x4, ((const float4*) (x + row * dim)) + column);
if (w)
{
float4 w4;
if constexpr (weight_bf16)
read_bfloat164(w4, ((const bfloat164*) w) + column);
else
read_half4<false>(w4, ((const half4*) w) + column);
if constexpr (weight_bf16) read_bfloat164 (w4, ((const bfloat164*) w) + column);
else read_half4<false>(w4, ((const half4*) w) + column);
if (constant_bias != 0.0f)
{
@@ -223,141 +223,75 @@ void rms_norm
bool span_heads
)
{
const at::cuda::OptionalCUDAGuard device_guard(x.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
if (span_heads)
{
x = x.flatten(-2);
y = y.flatten(-2);
}
const half* w_ptr = (const half*) OPTPTR(w);
TORCH_CHECK_DIV(x, -1, 4);
TORCH_CHECK_SHAPES_FULL(x, y);
bool weight_bf16 = false;
bool weight_fp16 = false;
auto tx = x.dtype();
auto tw = x.dtype(); // intentional: tw falls back to x's dtype when w is None (overwritten below if w is present)
auto ty = y.dtype();
const half* w_ptr = (const half*) OPTPTR(w);
if (w_ptr)
{
TORCH_CHECK_SHAPES(x, -1, w.value(), 0, 1);
weight_bf16 = w.value().dtype() == at::kBFloat16;
weight_fp16 = w.value().dtype() == at::kHalf;
tw = w.value().dtype();
}
const at::cuda::OptionalCUDAGuard device_guard(x.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
bool input_fp32 = x.dtype() == at::kFloat;
bool output_fp32 = y.dtype() == at::kFloat;
bool input_fp16 = !input_fp32;
bool output_fp16 = !output_fp32;
int rows = 1;
for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i);
int dim = x.size(-1);
dim3 blockDim(NUM_THREADS, 1, 1);
dim3 gridDim(rows, 1, 1);
if (input_fp16 && output_fp16 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
(const half*) w_ptr,
(half*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
// Launch macro
#define __(_tx, __tx, _tw, __tw, _ty, __ty) \
if (tx == at::_tx && tw == at::_tw && ty == at::_ty) \
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>> \
( \
(const __tx*) x.data_ptr(), \
(const __tw*) w_ptr, \
(__ty*) y.data_ptr(), \
epsilon, \
rows, \
dim, \
constant_bias \
);
else if (input_fp16 && output_fp32 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
(const half*) w_ptr,
(float*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp32 && output_fp16 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
(const half*) w_ptr,
(half*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp32 && output_fp32 && weight_fp16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
(const half*) w_ptr,
(float*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp16 && output_fp16 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
(const bfloat16*) w_ptr,
(half*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp16 && output_fp32 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const half*) x.data_ptr(),
(const bfloat16*) w_ptr,
(float*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp32 && output_fp16 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
(const bfloat16*) w_ptr,
(half*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (input_fp32 && output_fp32 && weight_bf16)
rms_norm_kernel<<<gridDim, blockDim, 0, stream>>>
(
(const float*) x.data_ptr(),
(const bfloat16*) w_ptr,
(float*) y.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else
TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output")
// x_type________ w_type_____________ y_type_______
__(kHalf, half, kHalf, half, kHalf, half )
else __(kHalf, half, kHalf, half, kFloat, float)
else __(kFloat, float, kHalf, half, kHalf, half )
else __(kFloat, float, kHalf, half, kFloat, float)
else __(kHalf, half, kBFloat16, bfloat16, kHalf, half )
else __(kHalf, half, kBFloat16, bfloat16, kFloat, float)
else __(kFloat, float, kBFloat16, bfloat16, kHalf, half )
else __(kFloat, float, kBFloat16, bfloat16, kFloat, float)
else TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output");
#undef __
cuda_check(cudaPeekAtLastError());
}
template <int num_threads, typename output_t, typename weight_t>
template <int num_threads, typename output_t, typename weight_t, typename gate_t>
__global__ __launch_bounds__(num_threads)
void gated_rms_norm_kernel
(
const bfloat16* __restrict__ x,
const weight_t* __restrict__ w,
output_t* __restrict__ y,
const bfloat16* __restrict__ g,
const gate_t* __restrict__ g,
const float epsilon,
const int rows,
const int dim,
@@ -367,7 +301,7 @@ void gated_rms_norm_kernel
constexpr bool output_fp32 = std::is_same_v<output_t, float>;
constexpr bool output_fp16 = std::is_same_v<output_t, half>;
constexpr bool weight_bf16 = std::is_same_v<weight_t, bfloat16>;
static_assert(output_fp32 || output_fp16, "gated_rms_norm_kernel: output must be float or half type");
constexpr bool gate_fp32 = std::is_same_v<gate_t, float>;
int t = threadIdx.x;
int warp_id = threadIdx.x / 32;
@@ -395,14 +329,12 @@ void gated_rms_norm_kernel
float4 x4;
float4 w4;
float4 g4;
read_bfloat164(x4, ((const bfloat164*) (x + row * dim)) + column);
if constexpr (weight_bf16)
read_bfloat164(w4, ((const bfloat164*) w) + column);
else
read_float4(w4, ((const float4*) w) + column);
read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column);
if constexpr (weight_bf16) read_bfloat164(w4, ((const bfloat164*) w) + column);
else read_float4 (w4, ((const float4*) w) + column);
if constexpr (gate_fp32) read_float4 (g4, ((const float4*) (g + row * dim)) + column);
else read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column);
if (constant_bias != 0.0f)
{
@@ -442,19 +374,10 @@ void gated_rms_norm
const at::cuda::OptionalCUDAGuard device_guard(x.device());
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
TORCH_CHECK_DTYPE(x, kBFloat16);
// TORCH_CHECK_DTYPE(w, kBFloat16);
TORCH_CHECK_DTYPE(g, kBFloat16);
TORCH_CHECK_DIV(x, -1, 4);
TORCH_CHECK_SHAPES(x, -1, w, 0, 1);
// TORCH_CHECK_SHAPES_FULL(x, y);
TORCH_CHECK_SHAPES_FULL(x, g);
bool output_fp32 = y.dtype() == at::kFloat;
bool output_fp16 = y.dtype() == at::kHalf;
bool weight_bf16 = w.dtype() == at::kBFloat16;
bool weight_fp32 = w.dtype() == at::kFloat;
int rows = 1;
for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i);
int dim = x.size(-1);
@@ -464,104 +387,46 @@ void gated_rms_norm
dim3 blockDim(small ? 32 : NUM_THREADS, 1, 1);
dim3 gridDim(rows, 1, 1);
if (!small && output_fp16 && weight_bf16)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(half*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
auto tx = x.dtype();
auto tw = w.dtype();
auto ty = y.dtype();
auto tg = g.dtype();
// Launch macro
#define __(_tx, __tx, _tw, __tw, _ty, __ty, _tg, __tg, _small, __num_threads) \
if (small == _small && tx == at::_tx && tw == at::_tw && ty == at::_ty && tg == at::_tg) \
gated_rms_norm_kernel<__num_threads><<<gridDim, blockDim, 0, stream>>> \
( \
(const __tx*) x.data_ptr(), \
(const __tw*) w.data_ptr(), \
(__ty*) y.data_ptr(), \
(const __tg*) g.data_ptr(), \
epsilon, \
rows, \
dim, \
constant_bias \
);
else if (!small && output_fp32 && weight_bf16)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(float*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (small && output_fp16 && weight_bf16)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(half*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (small && output_fp32 && weight_bf16)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const bfloat16*) w.data_ptr(),
(float*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (!small && output_fp16 && weight_fp32)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(half*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (!small && output_fp32 && weight_fp32)
gated_rms_norm_kernel<NUM_THREADS><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(float*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (small && output_fp16 && weight_fp32)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(half*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else if (small && output_fp32 && weight_fp32)
gated_rms_norm_kernel<32><<<gridDim, blockDim, 0, stream>>>
(
(const bfloat16*) x.data_ptr(),
(const float*) w.data_ptr(),
(float*) y.data_ptr(),
(const bfloat16*) g.data_ptr(),
epsilon,
rows,
dim,
constant_bias
);
else
TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output, must be half or float")
// x_type_____________ w_type_____________ y_type_______ g_type_____________ small num_threads
__(kBFloat16, bfloat16, kFloat, float, kHalf, half, kBFloat16, bfloat16, true, 32 )
else __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kBFloat16, bfloat16, false, NUM_THREADS)
else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kBFloat16, bfloat16, true, 32 )
else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kBFloat16, bfloat16, false, NUM_THREADS)
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kBFloat16, bfloat16, true, 32 )
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kBFloat16, bfloat16, false, NUM_THREADS)
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kBFloat16, bfloat16, true, 32 )
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kBFloat16, bfloat16, false, NUM_THREADS)
else __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kFloat, float, true, 32 )
else __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kFloat, float, false, NUM_THREADS)
else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kFloat, float, true, 32 )
else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kFloat, float, false, NUM_THREADS)
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kFloat, float, true, 32 )
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kFloat, float, false, NUM_THREADS)
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kFloat, float, true, 32 )
else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kFloat, float, false, NUM_THREADS)
else TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output");
#undef __
cuda_check(cudaPeekAtLastError());
}