mirror of
https://github.com/turboderp-org/exllamav3.git
synced 2026-04-20 14:29:51 +00:00
GatedDeltaNet: Allow bfloat16 a_log
This commit is contained in:
@@ -33,6 +33,16 @@ __device__ __forceinline__ float untrunc_bf16(bfloat16 x)
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float as_float(bfloat16 x)
|
||||
{
|
||||
return __bfloat162float(x);
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float as_float(float x)
|
||||
{
|
||||
return x;
|
||||
}
|
||||
|
||||
__device__ __forceinline__ float softplus(float x) // beta=1.0, linear threshold=20.0
|
||||
{
|
||||
if (x > 20.0f) return x;
|
||||
@@ -214,13 +224,13 @@ void gated_delta_net_fused_op
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
|
||||
template <typename a_log_T>
|
||||
__global__ void gated_delta_net_fused_op_2_kernel
|
||||
(
|
||||
const float* __restrict__ in_b, // [B,S,H]
|
||||
const float* __restrict__ in_a, // [B,S,H]
|
||||
const bfloat16* __restrict__ in_dt_bias, // [H]
|
||||
const float* __restrict__ in_a_log, // [H]
|
||||
const a_log_T* __restrict__ in_a_log, // [H]
|
||||
bfloat16* __restrict__ out_beta, // [B,S,H]
|
||||
float* __restrict__ out_g, // [B,S,H]
|
||||
int B,
|
||||
@@ -241,8 +251,8 @@ __global__ void gated_delta_net_fused_op_2_kernel
|
||||
out_g += row * H + t;
|
||||
|
||||
float beta = _sigmoid_fast_exp(*in_b);
|
||||
float dt_bias = untrunc_bf16(*in_dt_bias);
|
||||
float g = -softplus(*in_a + dt_bias) * __expf(*in_a_log);
|
||||
float dt_bias = as_float(*in_dt_bias);
|
||||
float g = -softplus(*in_a + dt_bias) * __expf(as_float(*in_a_log));
|
||||
|
||||
*out_beta = trunc_bf16(beta);
|
||||
*out_g = g;
|
||||
@@ -269,10 +279,12 @@ void gated_delta_net_fused_op_2
|
||||
TORCH_CHECK_DTYPE(b, kFloat);
|
||||
TORCH_CHECK_DTYPE(a, kFloat);
|
||||
TORCH_CHECK_DTYPE(dt_bias, kBFloat16);
|
||||
TORCH_CHECK_DTYPE(a_log, kFloat);
|
||||
TORCH_CHECK_DTYPE(beta, kBFloat16);
|
||||
TORCH_CHECK_DTYPE(g, kFloat);
|
||||
|
||||
bool a_log_fp32 = a_log.dtype() == at::kFloat;
|
||||
bool a_log_bf16 = a_log.dtype() == at::kBFloat16;
|
||||
|
||||
TORCH_CHECK_SHAPES_FULL(b, a);
|
||||
TORCH_CHECK_SHAPES(b, 2, dt_bias, 0, 1);
|
||||
TORCH_CHECK_SHAPES(b, 2, a_log, 0, 1);
|
||||
@@ -287,19 +299,25 @@ void gated_delta_net_fused_op_2
|
||||
int threads = rows_per_block * H;
|
||||
int blocks = CEIL_DIVIDE(B * S, rows_per_block);
|
||||
|
||||
gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>
|
||||
(
|
||||
(const float*) b.data_ptr(),
|
||||
(const float*) a.data_ptr(),
|
||||
(const bfloat16*) dt_bias.data_ptr(),
|
||||
(const float*) a_log.data_ptr(),
|
||||
(bfloat16*) beta.data_ptr(),
|
||||
(float*) g.data_ptr(),
|
||||
B,
|
||||
S,
|
||||
H,
|
||||
#define ARGS(a_log_T) \
|
||||
(const float*) b.data_ptr(), \
|
||||
(const float*) a.data_ptr(), \
|
||||
(const bfloat16*) dt_bias.data_ptr(), \
|
||||
(const a_log_T*) a_log.data_ptr(), \
|
||||
(bfloat16*) beta.data_ptr(), \
|
||||
(float*) g.data_ptr(), \
|
||||
B, \
|
||||
S, \
|
||||
H, \
|
||||
rows_per_block
|
||||
);
|
||||
|
||||
if (a_log_fp32)
|
||||
gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>(ARGS(float));
|
||||
else if (a_log_bf16)
|
||||
gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>(ARGS(bfloat16));
|
||||
else TORCH_CHECK(false, "gated_delta_net_fused_op_2: unsupported dtype");
|
||||
|
||||
#undef ARGS
|
||||
|
||||
cuda_check(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user