diff --git a/exllamav3/exllamav3_ext/norm.cu b/exllamav3/exllamav3_ext/norm.cu index 08f97de..dde8556 100644 --- a/exllamav3/exllamav3_ext/norm.cu +++ b/exllamav3/exllamav3_ext/norm.cu @@ -176,16 +176,16 @@ void rms_norm_kernel for (int column = t; column < columns; column += NUM_THREADS) { float4 x4; - if constexpr (input_fp16) read_half4(x4, ((const half4*) (x + row * dim)) + column); - if constexpr (input_fp32) read_float4(x4, ((const float4*) (x + row * dim)) + column); + + if constexpr (input_fp16) read_half4(x4, ((const half4*) (x + row * dim)) + column); + if constexpr (input_fp32) read_float4 (x4, ((const float4*) (x + row * dim)) + column); if (w) { float4 w4; - if constexpr (weight_bf16) - read_bfloat164(w4, ((const bfloat164*) w) + column); - else - read_half4(w4, ((const half4*) w) + column); + + if constexpr (weight_bf16) read_bfloat164 (w4, ((const bfloat164*) w) + column); + else read_half4(w4, ((const half4*) w) + column); if (constant_bias != 0.0f) { @@ -223,141 +223,75 @@ void rms_norm bool span_heads ) { + const at::cuda::OptionalCUDAGuard device_guard(x.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + if (span_heads) { x = x.flatten(-2); y = y.flatten(-2); } - const half* w_ptr = (const half*) OPTPTR(w); TORCH_CHECK_DIV(x, -1, 4); TORCH_CHECK_SHAPES_FULL(x, y); - bool weight_bf16 = false; - bool weight_fp16 = false; + auto tx = x.dtype(); + auto tw = x.dtype(); // intentional, type is if w is None + auto ty = y.dtype(); + + const half* w_ptr = (const half*) OPTPTR(w); if (w_ptr) { TORCH_CHECK_SHAPES(x, -1, w.value(), 0, 1); - weight_bf16 = w.value().dtype() == at::kBFloat16; - weight_fp16 = w.value().dtype() == at::kHalf; + tw = w.value().dtype(); } - const at::cuda::OptionalCUDAGuard device_guard(x.device()); - cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - - bool input_fp32 = x.dtype() == at::kFloat; - bool output_fp32 = y.dtype() == at::kFloat; - bool input_fp16 = !input_fp32; - bool output_fp16 = !output_fp32; int rows = 1; for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i); int dim = x.size(-1); + dim3 blockDim(NUM_THREADS, 1, 1); dim3 gridDim(rows, 1, 1); - if (input_fp16 && output_fp16 && weight_fp16) - rms_norm_kernel<<>> - ( - (const half*) x.data_ptr(), - (const half*) w_ptr, - (half*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias + // Launch macro + #define __(_tx, __tx, _tw, __tw, _ty, __ty) \ + if (tx == at::_tx && tw == at::_tw && ty == at::_ty) \ + rms_norm_kernel<<>> \ + ( \ + (const __tx*) x.data_ptr(), \ + (const __tw*) w_ptr, \ + (__ty*) y.data_ptr(), \ + epsilon, \ + rows, \ + dim, \ + constant_bias \ ); - else if (input_fp16 && output_fp32 && weight_fp16) - rms_norm_kernel<<>> - ( - (const half*) x.data_ptr(), - (const half*) w_ptr, - (float*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (input_fp32 && output_fp16 && weight_fp16) - rms_norm_kernel<<>> - ( - (const float*) x.data_ptr(), - (const half*) w_ptr, - (half*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (input_fp32 && output_fp32 && weight_fp16) - rms_norm_kernel<<>> - ( - (const float*) x.data_ptr(), - (const half*) w_ptr, - (float*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (input_fp16 && output_fp16 && weight_bf16) - rms_norm_kernel<<>> - ( - (const half*) x.data_ptr(), - (const bfloat16*) w_ptr, - (half*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (input_fp16 && output_fp32 && weight_bf16) - rms_norm_kernel<<>> - ( - (const half*) x.data_ptr(), - (const bfloat16*) w_ptr, - (float*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (input_fp32 && output_fp16 && weight_bf16) - rms_norm_kernel<<>> - ( - (const float*) x.data_ptr(), - (const bfloat16*) w_ptr, - (half*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (input_fp32 && output_fp32 && weight_bf16) - rms_norm_kernel<<>> - ( - (const float*) x.data_ptr(), - (const bfloat16*) w_ptr, - (float*) y.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else - TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output") + + // x_type________ w_type_____________ y_type_______ + __(kHalf, half, kHalf, half, kHalf, half ) + else __(kHalf, half, kHalf, half, kFloat, float) + else __(kFloat, float, kHalf, half, kHalf, half ) + else __(kFloat, float, kHalf, half, kFloat, float) + else __(kHalf, half, kBFloat16, bfloat16, kHalf, half ) + else __(kHalf, half, kBFloat16, bfloat16, kFloat, float) + else __(kFloat, float, kBFloat16, bfloat16, kHalf, half ) + else __(kFloat, float, kBFloat16, bfloat16, kFloat, float) + + else TORCH_CHECK(false, "rms_norm: Invalid datatypes for input/output"); + #undef __ cuda_check(cudaPeekAtLastError()); } -template +template __global__ __launch_bounds__(num_threads) void gated_rms_norm_kernel ( const bfloat16* __restrict__ x, const weight_t* __restrict__ w, output_t* __restrict__ y, - const bfloat16* __restrict__ g, + const gate_t* __restrict__ g, const float epsilon, const int rows, const int dim, @@ -367,7 +301,7 @@ void gated_rms_norm_kernel constexpr bool output_fp32 = std::is_same_v; constexpr bool output_fp16 = std::is_same_v; constexpr bool weight_bf16 = std::is_same_v; - static_assert(output_fp32 || output_fp16, "gated_rms_norm_kernel: output must be float or half type"); + constexpr bool gate_fp32 = std::is_same_v; int t = threadIdx.x; int warp_id = threadIdx.x / 32; @@ -395,14 +329,12 @@ void gated_rms_norm_kernel float4 x4; float4 w4; float4 g4; + read_bfloat164(x4, ((const bfloat164*) (x + row * dim)) + column); - - if constexpr (weight_bf16) - read_bfloat164(w4, ((const bfloat164*) w) + column); - else - read_float4(w4, ((const float4*) w) + column); - - read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column); + if constexpr (weight_bf16) read_bfloat164(w4, ((const bfloat164*) w) + column); + else read_float4 (w4, ((const float4*) w) + column); + if constexpr (gate_fp32) read_float4 (g4, ((const float4*) (g + row * dim)) + column); + else read_bfloat164(g4, ((const bfloat164*) (g + row * dim)) + column); if (constant_bias != 0.0f) { @@ -442,19 +374,10 @@ void gated_rms_norm const at::cuda::OptionalCUDAGuard device_guard(x.device()); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); - TORCH_CHECK_DTYPE(x, kBFloat16); - // TORCH_CHECK_DTYPE(w, kBFloat16); - TORCH_CHECK_DTYPE(g, kBFloat16); TORCH_CHECK_DIV(x, -1, 4); TORCH_CHECK_SHAPES(x, -1, w, 0, 1); - // TORCH_CHECK_SHAPES_FULL(x, y); TORCH_CHECK_SHAPES_FULL(x, g); - bool output_fp32 = y.dtype() == at::kFloat; - bool output_fp16 = y.dtype() == at::kHalf; - bool weight_bf16 = w.dtype() == at::kBFloat16; - bool weight_fp32 = w.dtype() == at::kFloat; - int rows = 1; for (int i = 0; i < x.dim() - 1; ++i) rows *= x.size(i); int dim = x.size(-1); @@ -464,104 +387,46 @@ void gated_rms_norm dim3 blockDim(small ? 32 : NUM_THREADS, 1, 1); dim3 gridDim(rows, 1, 1); - if (!small && output_fp16 && weight_bf16) - gated_rms_norm_kernel<<>> - ( - (const bfloat16*) x.data_ptr(), - (const bfloat16*) w.data_ptr(), - (half*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias + auto tx = x.dtype(); + auto tw = w.dtype(); + auto ty = y.dtype(); + auto tg = g.dtype(); + + // Launch macro + #define __(_tx, __tx, _tw, __tw, _ty, __ty, _tg, __tg, _small, __num_threads) \ + if (small == _small && tx == at::_tx && tw == at::_tw && ty == at::_ty && tg == at::_tg) \ + gated_rms_norm_kernel<__num_threads><<>> \ + ( \ + (const __tx*) x.data_ptr(), \ + (const __tw*) w.data_ptr(), \ + (__ty*) y.data_ptr(), \ + (const __tg*) g.data_ptr(), \ + epsilon, \ + rows, \ + dim, \ + constant_bias \ ); - else if (!small && output_fp32 && weight_bf16) - gated_rms_norm_kernel<<>> - ( - (const bfloat16*) x.data_ptr(), - (const bfloat16*) w.data_ptr(), - (float*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (small && output_fp16 && weight_bf16) - gated_rms_norm_kernel<32><<>> - ( - (const bfloat16*) x.data_ptr(), - (const bfloat16*) w.data_ptr(), - (half*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (small && output_fp32 && weight_bf16) - gated_rms_norm_kernel<32><<>> - ( - (const bfloat16*) x.data_ptr(), - (const bfloat16*) w.data_ptr(), - (float*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (!small && output_fp16 && weight_fp32) - gated_rms_norm_kernel<<>> - ( - (const bfloat16*) x.data_ptr(), - (const float*) w.data_ptr(), - (half*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (!small && output_fp32 && weight_fp32) - gated_rms_norm_kernel<<>> - ( - (const bfloat16*) x.data_ptr(), - (const float*) w.data_ptr(), - (float*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (small && output_fp16 && weight_fp32) - gated_rms_norm_kernel<32><<>> - ( - (const bfloat16*) x.data_ptr(), - (const float*) w.data_ptr(), - (half*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else if (small && output_fp32 && weight_fp32) - gated_rms_norm_kernel<32><<>> - ( - (const bfloat16*) x.data_ptr(), - (const float*) w.data_ptr(), - (float*) y.data_ptr(), - (const bfloat16*) g.data_ptr(), - epsilon, - rows, - dim, - constant_bias - ); - else - TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output, must be half or float") + + // x_type_____________ w_type_____________ y_type_______ g_type_____________ small num_threads + __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kBFloat16, bfloat16, true, 32 ) + else __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kBFloat16, bfloat16, false, NUM_THREADS) + else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kBFloat16, bfloat16, true, 32 ) + else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kBFloat16, bfloat16, false, NUM_THREADS) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kBFloat16, bfloat16, true, 32 ) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kBFloat16, bfloat16, false, NUM_THREADS) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kBFloat16, bfloat16, true, 32 ) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kBFloat16, bfloat16, false, NUM_THREADS) + else __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kFloat, float, true, 32 ) + else __(kBFloat16, bfloat16, kFloat, float, kHalf, half, kFloat, float, false, NUM_THREADS) + else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kFloat, float, true, 32 ) + else __(kBFloat16, bfloat16, kFloat, float, kFloat, float, kFloat, float, false, NUM_THREADS) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kFloat, float, true, 32 ) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kHalf, half, kFloat, float, false, NUM_THREADS) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kFloat, float, true, 32 ) + else __(kBFloat16, bfloat16, kBFloat16, bfloat16, kFloat, float, kFloat, float, false, NUM_THREADS) + + else TORCH_CHECK(false, "gated_rms_norm: Invalid datatypes for input/output"); + #undef __ cuda_check(cudaPeekAtLastError()); } \ No newline at end of file