mirror of
https://github.com/kvcache-ai/sglang.git
synced 2026-06-30 19:57:52 +00:00
Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: Ma Mingfei <mingfei.ma@intel.com>
1012 lines
34 KiB
C++
1012 lines
34 KiB
C++
#include "common.h"
|
|
#include "vec.h"
|
|
|
|
namespace {
|
|
|
|
// NB: avoid using `at::vec::map<>` on bfloat16 or half
|
|
// Llama4TextL2Norm
|
|
template <typename scalar_t>
|
|
void l2norm_kernel_impl(
|
|
scalar_t* __restrict__ output,
|
|
const scalar_t* __restrict__ input,
|
|
int64_t batch_size,
|
|
int64_t seq_len,
|
|
int64_t hidden_size,
|
|
int64_t input_strideB,
|
|
int64_t input_strideS,
|
|
int64_t output_strideB,
|
|
int64_t output_strideS,
|
|
float eps = 1e-5) {
|
|
using bVec = at::vec::Vectorized<scalar_t>;
|
|
using fVec = at::vec::Vectorized<float>;
|
|
|
|
constexpr int kVecSize = bVec::size();
|
|
at::parallel_for(0, batch_size * seq_len, 0, [&](int64_t begin, int64_t end) {
|
|
int64_t bi{0}, si{0};
|
|
data_index_init(begin, bi, batch_size, si, seq_len);
|
|
for (int64_t i = begin; i < end; ++i) {
|
|
// local ptrs
|
|
scalar_t* __restrict__ out_ptr = output + bi * output_strideB + si * output_strideS;
|
|
const scalar_t* __restrict__ input_ptr = input + bi * input_strideB + si * input_strideS;
|
|
|
|
fVec sum_fvec = fVec(float(0));
|
|
float sum_val = float(0);
|
|
|
|
int64_t d;
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
sum_fvec += x_fvec0 * x_fvec0;
|
|
sum_fvec += x_fvec1 * x_fvec1;
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
sum_val += x_val * x_val;
|
|
}
|
|
|
|
sum_val += vec_reduce_sum(sum_fvec);
|
|
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
|
const fVec scale_fvec = fVec(rsqrt_var);
|
|
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
x_fvec0 = x_fvec0 * scale_fvec;
|
|
x_fvec1 = x_fvec1 * scale_fvec;
|
|
|
|
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
out_bvec.store(out_ptr + d);
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var);
|
|
}
|
|
// move to the next index
|
|
data_index_step(bi, batch_size, si, seq_len);
|
|
}
|
|
});
|
|
}
|
|
|
|
template <typename scalar_t, typename func_t, typename vec_func_t>
|
|
void rmsnorm_kernel_impl(
|
|
scalar_t* __restrict__ output,
|
|
const scalar_t* __restrict__ input,
|
|
const scalar_t* __restrict__ weight,
|
|
int64_t batch_size,
|
|
int64_t seq_len,
|
|
int64_t hidden_size,
|
|
int64_t input_strideB,
|
|
int64_t input_strideS,
|
|
int64_t output_strideB,
|
|
int64_t output_strideS,
|
|
const func_t& f,
|
|
const vec_func_t& vf,
|
|
float eps = 1e-5) {
|
|
using bVec = at::vec::Vectorized<scalar_t>;
|
|
using fVec = at::vec::Vectorized<float>;
|
|
|
|
constexpr int kVecSize = bVec::size();
|
|
at::parallel_for(0, batch_size * seq_len, 0, [&](int64_t begin, int64_t end) {
|
|
int64_t bi{0}, si{0};
|
|
data_index_init(begin, bi, batch_size, si, seq_len);
|
|
for (int64_t i = begin; i < end; ++i) {
|
|
// local ptrs
|
|
scalar_t* __restrict__ out_ptr = output + bi * output_strideB + si * output_strideS;
|
|
const scalar_t* __restrict__ input_ptr = input + bi * input_strideB + si * input_strideS;
|
|
|
|
fVec sum_fvec = fVec(float(0));
|
|
float sum_val = float(0);
|
|
|
|
int64_t d;
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
sum_fvec += x_fvec0 * x_fvec0;
|
|
sum_fvec += x_fvec1 * x_fvec1;
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
sum_val += x_val * x_val;
|
|
}
|
|
|
|
sum_val += vec_reduce_sum(sum_fvec);
|
|
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
|
const fVec scale_fvec = fVec(rsqrt_var);
|
|
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
bVec w_bvec = bVec::loadu(weight + d);
|
|
fVec w_fvec0, w_fvec1;
|
|
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
|
|
|
x_fvec0 = x_fvec0 * scale_fvec * vf(w_fvec0);
|
|
x_fvec1 = x_fvec1 * scale_fvec * vf(w_fvec1);
|
|
|
|
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
out_bvec.store(out_ptr + d);
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
float w_val = static_cast<float>(weight[d]);
|
|
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * f(w_val));
|
|
}
|
|
// move to the next index
|
|
data_index_step(bi, batch_size, si, seq_len);
|
|
}
|
|
});
|
|
}
|
|
|
|
template <typename scalar_t>
|
|
void gemma3_rmsnorm_kernel_4d_impl(
|
|
scalar_t* __restrict__ output,
|
|
const scalar_t* __restrict__ input,
|
|
const scalar_t* __restrict__ weight,
|
|
int64_t batch_size,
|
|
int64_t num_head,
|
|
int64_t seq_len,
|
|
int64_t hidden_size,
|
|
int64_t input_strideB,
|
|
int64_t input_strideH,
|
|
int64_t input_strideS,
|
|
int64_t output_strideB,
|
|
int64_t output_strideH,
|
|
int64_t output_strideS,
|
|
float eps = 1e-5) {
|
|
using bVec = at::vec::Vectorized<scalar_t>;
|
|
using fVec = at::vec::Vectorized<float>;
|
|
|
|
constexpr int kVecSize = bVec::size();
|
|
at::parallel_for(0, batch_size * num_head * seq_len, 0, [&](int64_t begin, int64_t end) {
|
|
int64_t bi{0}, hi{0}, si{0};
|
|
data_index_init(begin, bi, batch_size, hi, num_head, si, seq_len);
|
|
for (int64_t i = begin; i < end; ++i) {
|
|
// local ptrs
|
|
scalar_t* __restrict__ out_ptr = output + bi * output_strideB + hi * output_strideH + si * output_strideS;
|
|
const scalar_t* __restrict__ input_ptr = input + bi * input_strideB + hi * input_strideH + si * input_strideS;
|
|
|
|
fVec sum_fvec = fVec(float(0));
|
|
float sum_val = float(0);
|
|
fVec one_fvec = fVec(float(1));
|
|
|
|
int64_t d;
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
sum_fvec += x_fvec0 * x_fvec0;
|
|
sum_fvec += x_fvec1 * x_fvec1;
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
sum_val += x_val * x_val;
|
|
}
|
|
|
|
sum_val += vec_reduce_sum(sum_fvec);
|
|
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
|
const fVec scale_fvec = fVec(rsqrt_var);
|
|
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
bVec w_bvec = bVec::loadu(weight + d);
|
|
fVec w_fvec0, w_fvec1;
|
|
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
|
|
|
x_fvec0 = x_fvec0 * scale_fvec * (w_fvec0 + one_fvec);
|
|
x_fvec1 = x_fvec1 * scale_fvec * (w_fvec1 + one_fvec);
|
|
|
|
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
out_bvec.store(out_ptr + d);
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
float w_val = static_cast<float>(weight[d]);
|
|
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * (w_val + 1));
|
|
}
|
|
// move to the next index
|
|
data_index_step(bi, batch_size, hi, num_head, si, seq_len);
|
|
}
|
|
});
|
|
}
|
|
|
|
template <typename scalar_t, typename func_t, typename vec_func_t>
|
|
void fused_add_rmsnorm_kernel_impl(
|
|
scalar_t* __restrict__ input,
|
|
scalar_t* __restrict__ residual,
|
|
const scalar_t* __restrict__ weight,
|
|
float* __restrict__ buffer,
|
|
int64_t batch_size,
|
|
int64_t seq_len,
|
|
int64_t hidden_size,
|
|
int64_t input_strideB,
|
|
int64_t input_strideS,
|
|
const func_t& f,
|
|
const vec_func_t& vf,
|
|
float eps = 1e-5) {
|
|
using bVec = at::vec::Vectorized<scalar_t>;
|
|
using fVec = at::vec::Vectorized<float>;
|
|
|
|
constexpr int kVecSize = bVec::size();
|
|
at::parallel_for(0, batch_size * seq_len, 0, [&](int64_t begin, int64_t end) {
|
|
int64_t bi{0}, si{0};
|
|
data_index_init(begin, bi, batch_size, si, seq_len);
|
|
int tid = at::get_thread_num();
|
|
float* __restrict__ buffer_ptr = buffer + tid * hidden_size;
|
|
|
|
for (int64_t i = begin; i < end; ++i) {
|
|
// local ptrs
|
|
scalar_t* __restrict__ input_ptr = input + bi * input_strideB + si * input_strideS;
|
|
scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
|
|
|
|
fVec sum_fvec = fVec(float(0));
|
|
float sum_val = float(0);
|
|
|
|
int64_t d;
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
bVec r_bvec = bVec::loadu(residual_ptr + d);
|
|
fVec r_fvec0, r_fvec1;
|
|
std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);
|
|
|
|
x_fvec0 += r_fvec0;
|
|
x_fvec1 += r_fvec1;
|
|
|
|
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
out_bvec.store(residual_ptr + d);
|
|
|
|
sum_fvec += x_fvec0 * x_fvec0;
|
|
sum_fvec += x_fvec1 * x_fvec1;
|
|
|
|
x_fvec0.store(buffer_ptr + d);
|
|
x_fvec1.store(buffer_ptr + d + fVec::size());
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
float r_val = static_cast<float>(residual_ptr[d]);
|
|
|
|
x_val += r_val;
|
|
residual_ptr[d] = static_cast<scalar_t>(x_val);
|
|
|
|
sum_val += x_val * x_val;
|
|
buffer_ptr[d] = x_val;
|
|
}
|
|
|
|
sum_val += vec_reduce_sum(sum_fvec);
|
|
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
|
const fVec scale_fvec = fVec(rsqrt_var);
|
|
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
|
|
fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());
|
|
|
|
bVec w_bvec = bVec::loadu(weight + d);
|
|
fVec w_fvec0, w_fvec1;
|
|
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
|
|
|
x_fvec0 = x_fvec0 * scale_fvec * vf(w_fvec0);
|
|
x_fvec1 = x_fvec1 * scale_fvec * vf(w_fvec1);
|
|
bVec x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
x_bvec.store(input_ptr + d);
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = buffer_ptr[d] * rsqrt_var * static_cast<float>(f(weight[d]));
|
|
input_ptr[d] = x_val;
|
|
}
|
|
// move to the next index
|
|
data_index_step(bi, batch_size, si, seq_len);
|
|
}
|
|
});
|
|
}
|
|
|
|
template <typename scalar_t>
|
|
void fused_rmsnorm_gated_kernel_impl(
|
|
scalar_t* __restrict__ output,
|
|
const scalar_t* __restrict__ input,
|
|
const scalar_t* __restrict__ weight,
|
|
const scalar_t* __restrict__ gate,
|
|
int64_t batch_size,
|
|
int64_t hidden_size,
|
|
int64_t input_strideN,
|
|
float eps = 1e-5) {
|
|
using bVec = at::vec::Vectorized<scalar_t>;
|
|
using fVec = at::vec::Vectorized<float>;
|
|
const fVec one = fVec(1.f);
|
|
|
|
constexpr int kVecSize = bVec::size();
|
|
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
|
|
for (int64_t i = begin; i < end; ++i) {
|
|
// local ptrs
|
|
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
|
|
const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
|
|
const scalar_t* __restrict__ gate_ptr = gate + i * hidden_size;
|
|
|
|
fVec sum_fvec = fVec(float(0));
|
|
float sum_val = float(0);
|
|
|
|
int64_t d;
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
sum_fvec += x_fvec0 * x_fvec0;
|
|
sum_fvec += x_fvec1 * x_fvec1;
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
sum_val += x_val * x_val;
|
|
}
|
|
|
|
sum_val += vec_reduce_sum(sum_fvec);
|
|
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
|
|
const fVec scale_fvec = fVec(rsqrt_var);
|
|
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
bVec w_bvec = bVec::loadu(weight + d);
|
|
fVec w_fvec0, w_fvec1;
|
|
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
|
|
|
bVec g_bvec = bVec::loadu(gate_ptr + d);
|
|
fVec g_fvec0, g_fvec1;
|
|
std::tie(g_fvec0, g_fvec1) = at::vec::convert_to_float(g_bvec);
|
|
g_fvec0 = g_fvec0 / (one + g_fvec0.neg().exp_u20());
|
|
g_fvec1 = g_fvec1 / (one + g_fvec1.neg().exp_u20());
|
|
|
|
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0 * g_fvec0;
|
|
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1 * g_fvec1;
|
|
|
|
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
out_bvec.store(out_ptr + d);
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
float w_val = static_cast<float>(weight[d]);
|
|
float g_val = static_cast<float>(gate_ptr[d]);
|
|
|
|
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * w_val * g_val / (1.f + std::exp(-g_val)));
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
template <typename scalar_t>
|
|
void fused_add_layernorm_kernel_impl(
|
|
scalar_t* __restrict__ output,
|
|
const scalar_t* __restrict__ input,
|
|
scalar_t* __restrict__ residual,
|
|
const scalar_t* __restrict__ weight,
|
|
const scalar_t* __restrict__ bias,
|
|
float* __restrict__ buffer,
|
|
int64_t batch_size,
|
|
int64_t seq_len,
|
|
int64_t hidden_size,
|
|
int64_t input_strideN,
|
|
float eps = 1e-5) {
|
|
using bVec = at::vec::Vectorized<scalar_t>;
|
|
using fVec = at::vec::Vectorized<float>;
|
|
constexpr int kVecSize = bVec::size();
|
|
|
|
const bool has_residual{residual != nullptr};
|
|
const bool has_bias{bias != nullptr};
|
|
const int64_t parallel_size{batch_size * seq_len};
|
|
at::parallel_for(0, parallel_size, 0, [&](int64_t begin, int64_t end) {
|
|
float* __restrict__ buffer_ptr = buffer + at::get_thread_num() * hidden_size;
|
|
|
|
for (int64_t i = begin; i < end; ++i) {
|
|
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
|
|
const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
|
|
scalar_t* __restrict__ residual_ptr{(scalar_t*)nullptr};
|
|
if (has_residual) {
|
|
residual_ptr = residual + i * hidden_size;
|
|
}
|
|
|
|
// First pass: compute mean and var
|
|
fVec sum_fvec{fVec(0.0)}, sum_sq_fvec{fVec(0.0)};
|
|
float sum_val{0.0}, sum_sq_val{0.0};
|
|
int64_t d{0};
|
|
|
|
#pragma GCC unroll 4
|
|
for (; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
bVec x_bvec = bVec::loadu(input_ptr + d);
|
|
fVec x_fvec0, x_fvec1;
|
|
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
|
|
|
if (has_residual) {
|
|
bVec r_bvec = bVec::loadu(residual_ptr + d);
|
|
fVec r_fvec0, r_fvec1;
|
|
std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);
|
|
|
|
x_fvec0 += r_fvec0;
|
|
x_fvec1 += r_fvec1;
|
|
|
|
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
out_bvec.store(residual_ptr + d);
|
|
}
|
|
|
|
sum_fvec += x_fvec0;
|
|
sum_fvec += x_fvec1;
|
|
sum_sq_fvec += x_fvec0 * x_fvec0;
|
|
sum_sq_fvec += x_fvec1 * x_fvec1;
|
|
|
|
x_fvec0.store(buffer_ptr + d);
|
|
x_fvec1.store(buffer_ptr + d + fVec::size());
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float x_val = static_cast<float>(input_ptr[d]);
|
|
if (has_residual) {
|
|
float r_val = static_cast<float>(residual_ptr[d]);
|
|
x_val += r_val;
|
|
residual_ptr[d] = static_cast<scalar_t>(x_val);
|
|
}
|
|
|
|
sum_val += x_val;
|
|
sum_sq_val += x_val * x_val;
|
|
buffer_ptr[d] = x_val;
|
|
}
|
|
|
|
// Var(X) = E(X^2) - (E(X))^2
|
|
// Refer to FlashInfer impl:
|
|
// https://github.com/flashinfer-ai/flashinfer/blob/6bb01d19c2d9ab3b6a3a5e9e97448891a5ed2844/include/flashinfer/norm.cuh#L554
|
|
sum_val += vec_reduce_sum(sum_fvec);
|
|
sum_sq_val += vec_reduce_sum(sum_sq_fvec);
|
|
|
|
float mean{sum_val / hidden_size};
|
|
float mean_sq{sum_sq_val / hidden_size};
|
|
float variance{mean_sq - (mean * mean)};
|
|
float rsqrt_var{float(1) / std::sqrt(variance + eps)};
|
|
|
|
const fVec mean_fvec = fVec(mean);
|
|
const fVec scale_fvec = fVec(rsqrt_var);
|
|
|
|
// Second pass: apply normalization
|
|
#pragma GCC unroll 4
|
|
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
|
|
fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
|
|
fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());
|
|
bVec w_bvec = bVec::loadu(weight + d);
|
|
fVec w_fvec0, w_fvec1;
|
|
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
|
|
|
|
x_fvec0 = (x_fvec0 - mean_fvec) * scale_fvec * w_fvec0;
|
|
x_fvec1 = (x_fvec1 - mean_fvec) * scale_fvec * w_fvec1;
|
|
|
|
if (has_bias) {
|
|
bVec b_bvec = bVec::loadu(bias + d);
|
|
fVec b_fvec0, b_fvec1;
|
|
std::tie(b_fvec0, b_fvec1) = at::vec::convert_to_float(b_bvec);
|
|
x_fvec0 += b_fvec0;
|
|
x_fvec1 += b_fvec1;
|
|
}
|
|
|
|
bVec o_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
|
|
o_bvec.store(out_ptr + d);
|
|
}
|
|
#pragma GCC unroll 4
|
|
for (; d < hidden_size; ++d) {
|
|
float normalized = (buffer_ptr[d] - mean) * rsqrt_var;
|
|
float x_val = normalized * static_cast<float>(weight[d]);
|
|
if (has_bias) {
|
|
x_val += static_cast<float>(bias[d]);
|
|
}
|
|
out_ptr[d] = static_cast<scalar_t>(x_val);
|
|
}
|
|
}
|
|
});
|
|
} // anonymous namespace
|
|
|
|
// input : {batch_size, hidden_size}
|
|
at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::l2norm_cpu", std::vector<c10::IValue>({input}));
|
|
|
|
CHECK_INPUT(input);
|
|
CHECK_DIM(2, input);
|
|
int64_t batch_size = input.size(0);
|
|
int64_t hidden_size = input.size(1);
|
|
at::Tensor output = at::empty_like(input);
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "l2norm_kernel", [&] {
|
|
l2norm_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
batch_size,
|
|
1,
|
|
hidden_size,
|
|
hidden_size,
|
|
0,
|
|
hidden_size,
|
|
0,
|
|
eps);
|
|
});
|
|
return output;
|
|
}
|
|
|
|
// input : {batch_size, hidden_size} or {batch_size, seq_len, hidden_size}
|
|
// weight: {hidden_size}
|
|
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
|
|
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(weight);
|
|
int64_t inp_dim{input.dim()};
|
|
TORCH_CHECK(inp_dim == 2 || inp_dim == 3, "Expected input dim to be 2 or 3, but got ", inp_dim);
|
|
CHECK_DIM(1, weight);
|
|
CHECK_EQ(input.size(-1), weight.size(0));
|
|
|
|
int64_t batch_size = input.size(0);
|
|
int64_t seq_len = 1;
|
|
int64_t hidden_size = input.size(-1);
|
|
int64_t input_strideB = input.stride(0);
|
|
int64_t input_strideS = 0;
|
|
at::Tensor output = at::empty_like(input);
|
|
int64_t output_strideB = output.stride(0);
|
|
int64_t output_strideS = 0;
|
|
if (inp_dim == 3) {
|
|
seq_len = input.size(1);
|
|
input_strideS = input.stride(1);
|
|
output_strideS = output.stride(1);
|
|
}
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
|
|
using Vec = at::vec::Vectorized<float>;
|
|
rmsnorm_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
batch_size,
|
|
seq_len,
|
|
hidden_size,
|
|
input_strideB,
|
|
input_strideS,
|
|
output_strideB,
|
|
output_strideS,
|
|
[](float x) { return x; },
|
|
[](Vec x) { return x; },
|
|
eps);
|
|
});
|
|
return output;
|
|
}
|
|
|
|
// input : {batch_size, hidden_size} or {batch_size, seq_len, hidden_size}
|
|
// weight: {hidden_size}
|
|
// bias : {hidden_size}
|
|
at::Tensor
|
|
layernorm_cpu(const at::Tensor& input, const at::Tensor& weight, const std::optional<at::Tensor>& bias, double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::layernorm_cpu", std::vector<c10::IValue>({input, weight}));
|
|
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(weight);
|
|
int64_t inp_dim{input.dim()};
|
|
TORCH_CHECK(inp_dim == 2 || inp_dim == 3, "Expected input dim to be 2 or 3, but got ", inp_dim);
|
|
CHECK_DIM(1, weight);
|
|
if (bias.has_value()) {
|
|
CHECK_DIM(1, bias.value());
|
|
CHECK_EQ(bias.value().size(0), weight.size(0));
|
|
}
|
|
|
|
int64_t batch_size{input.size(0)}, seq_len{1}, hidden_size{input.size(1)}, input_strideN{input.stride(0)};
|
|
if (inp_dim == 3) {
|
|
CHECK_EQ(input.size(2), weight.size(0));
|
|
seq_len = input.size(1);
|
|
hidden_size = input.size(2);
|
|
input_strideN = input.stride(1);
|
|
} else {
|
|
CHECK_EQ(input.size(1), weight.size(0));
|
|
}
|
|
|
|
at::Tensor output = at::empty_like(input);
|
|
int64_t num_threads = at::get_num_threads();
|
|
at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "layernorm_kernel", [&] {
|
|
fused_add_layernorm_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
nullptr,
|
|
weight.data_ptr<scalar_t>(),
|
|
conditional_data_ptr<scalar_t>(bias),
|
|
buffer.data_ptr<float>(),
|
|
batch_size,
|
|
seq_len,
|
|
hidden_size,
|
|
input_strideN,
|
|
eps);
|
|
});
|
|
return output;
|
|
}
|
|
|
|
// Gemma RMSNorm: out = x * rsqrt(mean(x^2) + eps) * (1 + weight).
// input : {batch_size, hidden_size} (last dim contiguous)
// weight: {hidden_size}
at::Tensor gemma_rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
  RECORD_FUNCTION("sgl-kernel::gemma_rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));

  CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
  CHECK_INPUT(weight);
  CHECK_DIM(2, input);
  CHECK_DIM(1, weight);
  CHECK_EQ(input.size(1), weight.size(0));

  const int64_t num_rows = input.size(0);
  const int64_t row_len = input.size(1);
  at::Tensor output = at::empty_like(input);

  AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gemma_rmsnorm_kernel", [&] {
    using Vec = at::vec::Vectorized<float>;
    const Vec ones = Vec(float(1));
    rmsnorm_kernel_impl<scalar_t>(
        output.data_ptr<scalar_t>(),
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        num_rows,
        /*seq_len=*/1,
        row_len,
        /*input_strideB=*/input.stride(0),
        /*input_strideS=*/0,
        /*output_strideB=*/output.stride(0),
        /*output_strideS=*/0,
        [](float w) { return w + 1; },
        [ones](Vec w) { return w + ones; },
        eps);
  });
  return output;
}
|
|
|
|
// input : {batch_size, hidden_size} or {batch_size, num_head, seq_len, head_dim}
|
|
// weight: {hidden_size}
|
|
at::Tensor gemma3_rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::gemma3_rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
|
|
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(weight);
|
|
TORCH_CHECK(
|
|
input.dim() == 2 || input.dim() == 4, "gemma3_rmsnorm_cpu: input must be 2D or 4D, got ", input.dim(), "D");
|
|
CHECK_DIM(1, weight);
|
|
CHECK_EQ(input.size(-1), weight.size(0));
|
|
int64_t batch_size = input.size(0);
|
|
int64_t hidden_size = weight.size(0);
|
|
at::Tensor output = at::empty_like(input);
|
|
if (input.dim() == 2) {
|
|
int64_t input_strideN = input.stride(0);
|
|
int64_t output_strideN = output.stride(0);
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gemma3_rmsnorm_kernel", [&] {
|
|
using Vec = at::vec::Vectorized<float>;
|
|
Vec one_vec = Vec(float(1));
|
|
rmsnorm_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
batch_size,
|
|
1,
|
|
hidden_size,
|
|
input_strideN,
|
|
0,
|
|
output_strideN,
|
|
0,
|
|
[](float x) { return x + 1; },
|
|
[one_vec](Vec x) { return x + one_vec; },
|
|
eps);
|
|
});
|
|
} else {
|
|
int64_t input_strideB = input.stride(0);
|
|
int64_t input_strideH = input.stride(1);
|
|
int64_t input_strideS = input.stride(2);
|
|
int64_t output_strideB = output.stride(0);
|
|
int64_t output_strideH = output.stride(1);
|
|
int64_t output_strideS = output.stride(2);
|
|
int64_t num_head = input.size(1);
|
|
int64_t seq_len = input.size(2);
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gemma3_rmsnorm_kernel", [&] {
|
|
gemma3_rmsnorm_kernel_4d_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
batch_size,
|
|
num_head,
|
|
seq_len,
|
|
hidden_size,
|
|
input_strideB,
|
|
input_strideH,
|
|
input_strideS,
|
|
output_strideB,
|
|
output_strideH,
|
|
output_strideS,
|
|
eps);
|
|
});
|
|
}
|
|
return output;
|
|
}
|
|
|
|
// Gemma4RMSNorm: with_scale ? norm(x) * (weight + scale_shift) : norm(x)
|
|
// input : {batch_size, hidden_size} or {batch_size, seq_len, hidden_size}
|
|
// weight: {hidden_size}
|
|
at::Tensor gemma4_rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps, double scale_shift, bool with_scale) {
|
|
RECORD_FUNCTION("sgl-kernel::gemma4_rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
|
|
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(weight);
|
|
int64_t inp_dim{input.dim()};
|
|
TORCH_CHECK(inp_dim == 2 || inp_dim == 3, "gemma4_rmsnorm_cpu: expected input dim 2 or 3, got ", inp_dim);
|
|
CHECK_DIM(1, weight);
|
|
CHECK_EQ(input.size(-1), weight.size(0));
|
|
|
|
int64_t hidden_size = input.size(-1);
|
|
at::Tensor output = at::empty_like(input);
|
|
int64_t batch_size = input.size(0);
|
|
int64_t seq_len = 1;
|
|
int64_t input_strideB = input.stride(0);
|
|
int64_t input_strideS = 0;
|
|
int64_t output_strideB = output.stride(0);
|
|
int64_t output_strideS = 0;
|
|
if (inp_dim == 3) {
|
|
seq_len = input.size(1);
|
|
input_strideS = input.stride(1);
|
|
output_strideS = output.stride(1);
|
|
}
|
|
|
|
if (with_scale) {
|
|
float shift = static_cast<float>(scale_shift);
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gemma4_rmsnorm_kernel", [&] {
|
|
using Vec = at::vec::Vectorized<float>;
|
|
Vec shift_vec = Vec(shift);
|
|
rmsnorm_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
batch_size,
|
|
seq_len,
|
|
hidden_size,
|
|
input_strideB,
|
|
input_strideS,
|
|
output_strideB,
|
|
output_strideS,
|
|
[shift](float x) { return x + shift; },
|
|
[shift_vec](Vec x) { return x + shift_vec; },
|
|
eps);
|
|
});
|
|
} else {
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gemma4_rmsnorm_kernel", [&] {
|
|
l2norm_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
batch_size,
|
|
seq_len,
|
|
hidden_size,
|
|
input_strideB,
|
|
input_strideS,
|
|
output_strideB,
|
|
output_strideS,
|
|
eps);
|
|
});
|
|
}
|
|
return output;
|
|
}
|
|
|
|
// input : {batch_size, hidden_size}
|
|
// weight: {hidden_size}
|
|
// gate: {batch_size, hidden_size}
|
|
at::Tensor fused_rmsnorm_gated_cpu(at::Tensor& input, at::Tensor& weight, at::Tensor& gate, double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::fused_rmsnorm_gated_cpu", std::vector<c10::IValue>({input, weight, gate}));
|
|
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(weight);
|
|
CHECK_INPUT(gate);
|
|
CHECK_DIM(2, input);
|
|
CHECK_DIM(1, weight);
|
|
CHECK_DIM(2, gate);
|
|
CHECK_EQ(input.size(1), weight.size(0));
|
|
int64_t batch_size = input.size(0);
|
|
int64_t hidden_size = input.size(1);
|
|
CHECK_EQ(input.size(0), gate.size(0));
|
|
CHECK_EQ(input.size(1), gate.size(1));
|
|
at::Tensor output = at::empty_like(input);
|
|
int64_t input_strideN = input.stride(0);
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_rmsnorm_gated_kernel", [&] {
|
|
fused_rmsnorm_gated_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
gate.data_ptr<scalar_t>(),
|
|
batch_size,
|
|
hidden_size,
|
|
input_strideN,
|
|
eps);
|
|
});
|
|
return output;
|
|
}
|
|
|
|
// input : {batch_size, hidden_size} or {batch_size, seq_len, hidden_size}
|
|
// residual: {batch_size, hidden_size} or {batch_size, seq_len, hidden_size}
|
|
// weight : {hidden_size}
|
|
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(residual);
|
|
CHECK_INPUT(weight);
|
|
int64_t inp_dim{input.dim()}, res_dim{residual.dim()};
|
|
CHECK_EQ(inp_dim, res_dim);
|
|
TORCH_CHECK(inp_dim == 2 || inp_dim == 3, "Expected input dim to be 2 or 3, but got ", inp_dim);
|
|
CHECK_DIM(1, weight);
|
|
CHECK_EQ(input.size(0), residual.size(0));
|
|
CHECK_EQ(input.size(-1), residual.size(-1));
|
|
CHECK_EQ(input.size(-1), weight.size(0));
|
|
|
|
int64_t batch_size = input.size(0);
|
|
int64_t seq_len = 1;
|
|
int64_t hidden_size = input.size(-1);
|
|
int64_t input_strideB = input.stride(0);
|
|
int64_t input_strideS = 0;
|
|
if (inp_dim == 3) {
|
|
seq_len = input.size(1);
|
|
input_strideS = input.stride(1);
|
|
}
|
|
|
|
// allocate temp buffer to store x in float32 per thread
|
|
// TODO: implement a singleton for context
|
|
int64_t num_threads = at::get_num_threads();
|
|
at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] {
|
|
using Vec = at::vec::Vectorized<float>;
|
|
fused_add_rmsnorm_kernel_impl<scalar_t>(
|
|
input.data_ptr<scalar_t>(),
|
|
residual.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
buffer.data_ptr<float>(),
|
|
batch_size,
|
|
seq_len,
|
|
hidden_size,
|
|
input_strideB,
|
|
input_strideS,
|
|
[](float x) { return x; },
|
|
[](Vec x) { return x; },
|
|
eps);
|
|
});
|
|
}
|
|
|
|
// input : {batch_size, hidden_size}
|
|
// residual: {batch_size, hidden_size}
|
|
// weight : {hidden_size}
|
|
void gemma_fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::gemma_fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(residual);
|
|
CHECK_INPUT(weight);
|
|
CHECK_DIM(2, input);
|
|
CHECK_DIM(2, residual);
|
|
CHECK_DIM(1, weight);
|
|
CHECK_EQ(input.size(0), residual.size(0));
|
|
CHECK_EQ(input.size(1), residual.size(1));
|
|
CHECK_EQ(input.size(1), weight.size(0));
|
|
int64_t batch_size = input.size(0);
|
|
int64_t hidden_size = input.size(1);
|
|
int64_t input_strideN = input.stride(0);
|
|
|
|
// allocate temp buffer to store x in float32 per thread
|
|
// TODO: implement a singleton for context
|
|
int64_t num_threads = at::get_num_threads();
|
|
at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "gemma_fused_add_rmsnorm_kernel", [&] {
|
|
using Vec = at::vec::Vectorized<float>;
|
|
Vec one_vec = Vec(float(1));
|
|
fused_add_rmsnorm_kernel_impl<scalar_t>(
|
|
input.data_ptr<scalar_t>(),
|
|
residual.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
buffer.data_ptr<float>(),
|
|
batch_size,
|
|
1,
|
|
hidden_size,
|
|
input_strideN,
|
|
0,
|
|
[](float x) { return x + 1; },
|
|
[one_vec](Vec x) { return x + one_vec; },
|
|
eps);
|
|
});
|
|
}
|
|
|
|
// input : {batch_size, hidden_size} or {batch_size, seq_len, hidden_size}
|
|
// residual: {batch_size, hidden_size} or {batch_size, seq_len, hidden_size}
|
|
// weight : {hidden_size}
|
|
// bias : {hidden_size}
|
|
at::Tensor fused_add_layernorm_cpu(
|
|
const at::Tensor& input,
|
|
at::Tensor& residual,
|
|
const at::Tensor& weight,
|
|
const std::optional<at::Tensor>& bias,
|
|
double eps) {
|
|
RECORD_FUNCTION("sgl-kernel::fused_add_layernorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
|
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
|
|
CHECK_INPUT(residual);
|
|
CHECK_INPUT(weight);
|
|
int64_t inp_dim{input.dim()}, res_dim{residual.dim()};
|
|
CHECK_EQ(inp_dim, res_dim);
|
|
TORCH_CHECK(inp_dim == 2 || inp_dim == 3, "Expected input dim to be 2 or 3, but got ", inp_dim);
|
|
TORCH_CHECK(res_dim == 2 || res_dim == 3, "Expected residual dim to be 2 or 3, but got ", res_dim);
|
|
|
|
CHECK_DIM(1, weight);
|
|
if (bias.has_value()) {
|
|
CHECK_DIM(1, bias.value());
|
|
CHECK_EQ(bias.value().size(0), weight.size(0));
|
|
}
|
|
CHECK_EQ(input.size(0), residual.size(0));
|
|
CHECK_EQ(input.size(1), residual.size(1));
|
|
if (inp_dim == 3) {
|
|
CHECK_EQ(input.size(2), residual.size(2));
|
|
CHECK_EQ(input.size(2), weight.size(0));
|
|
} else {
|
|
CHECK_EQ(input.size(1), weight.size(0));
|
|
}
|
|
|
|
int64_t batch_size{input.size(0)}, seq_len{1}, hidden_size{input.size(1)}, input_strideN{input.stride(0)};
|
|
if (inp_dim == 3) {
|
|
seq_len = input.size(1);
|
|
hidden_size = input.size(2);
|
|
input_strideN = input.stride(1);
|
|
}
|
|
at::Tensor output = at::empty_like(input);
|
|
|
|
// Allocate temp buffer to store x in float32 per thread
|
|
// It is necessary to store FP32 precision of residual-add results to pass UT acc test
|
|
int64_t num_threads = at::get_num_threads();
|
|
at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));
|
|
|
|
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_layernorm_kernel", [&] {
|
|
fused_add_layernorm_kernel_impl<scalar_t>(
|
|
output.data_ptr<scalar_t>(),
|
|
input.data_ptr<scalar_t>(),
|
|
residual.data_ptr<scalar_t>(),
|
|
weight.data_ptr<scalar_t>(),
|
|
conditional_data_ptr<scalar_t>(bias),
|
|
buffer.data_ptr<float>(),
|
|
batch_size,
|
|
seq_len,
|
|
hidden_size,
|
|
input_strideN,
|
|
eps);
|
|
});
|
|
return output;
|
|
}
|