GatedDeltaNet: Allow bfloat16 a_log

This commit is contained in:
turboderp
2026-03-11 20:24:04 +01:00
parent ad546f7937
commit d52c49c17f

View File

@@ -33,6 +33,16 @@ __device__ __forceinline__ float untrunc_bf16(bfloat16 x)
return __bfloat162float(x);
}
// Widen a bfloat16 to float via __bfloat162float (bf16 overload of as_float).
__device__ __forceinline__ float as_float(bfloat16 x)
{
return __bfloat162float(x);
}
// Identity overload: lets code templated on the storage type (e.g. the
// a_log_T kernel parameter) call as_float uniformly on float or bfloat16.
__device__ __forceinline__ float as_float(float x)
{
return x;
}
__device__ __forceinline__ float softplus(float x) // beta=1.0, linear threshold=20.0
{
if (x > 20.0f) return x;
@@ -214,13 +224,13 @@ void gated_delta_net_fused_op
cuda_check(cudaPeekAtLastError());
}
template <typename a_log_T>
__global__ void gated_delta_net_fused_op_2_kernel
(
const float* __restrict__ in_b, // [B,S,H]
const float* __restrict__ in_a, // [B,S,H]
const bfloat16* __restrict__ in_dt_bias, // [H]
const float* __restrict__ in_a_log, // [H]
const a_log_T* __restrict__ in_a_log, // [H]
bfloat16* __restrict__ out_beta, // [B,S,H]
float* __restrict__ out_g, // [B,S,H]
int B,
@@ -241,8 +251,8 @@ __global__ void gated_delta_net_fused_op_2_kernel
out_g += row * H + t;
float beta = _sigmoid_fast_exp(*in_b);
float dt_bias = untrunc_bf16(*in_dt_bias);
float g = -softplus(*in_a + dt_bias) * __expf(*in_a_log);
float dt_bias = as_float(*in_dt_bias);
float g = -softplus(*in_a + dt_bias) * __expf(as_float(*in_a_log));
*out_beta = trunc_bf16(beta);
*out_g = g;
@@ -269,10 +279,12 @@ void gated_delta_net_fused_op_2
TORCH_CHECK_DTYPE(b, kFloat);
TORCH_CHECK_DTYPE(a, kFloat);
TORCH_CHECK_DTYPE(dt_bias, kBFloat16);
TORCH_CHECK_DTYPE(a_log, kFloat);
TORCH_CHECK_DTYPE(beta, kBFloat16);
TORCH_CHECK_DTYPE(g, kFloat);
bool a_log_fp32 = a_log.dtype() == at::kFloat;
bool a_log_bf16 = a_log.dtype() == at::kBFloat16;
TORCH_CHECK_SHAPES_FULL(b, a);
TORCH_CHECK_SHAPES(b, 2, dt_bias, 0, 1);
TORCH_CHECK_SHAPES(b, 2, a_log, 0, 1);
@@ -287,19 +299,25 @@ void gated_delta_net_fused_op_2
int threads = rows_per_block * H;
int blocks = CEIL_DIVIDE(B * S, rows_per_block);
gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>
(
(const float*) b.data_ptr(),
(const float*) a.data_ptr(),
(const bfloat16*) dt_bias.data_ptr(),
(const float*) a_log.data_ptr(),
(bfloat16*) beta.data_ptr(),
(float*) g.data_ptr(),
B,
S,
H,
#define ARGS(a_log_T) \
(const float*) b.data_ptr(), \
(const float*) a.data_ptr(), \
(const bfloat16*) dt_bias.data_ptr(), \
(const a_log_T*) a_log.data_ptr(), \
(bfloat16*) beta.data_ptr(), \
(float*) g.data_ptr(), \
B, \
S, \
H, \
rows_per_block
);
if (a_log_fp32)
gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>(ARGS(float));
else if (a_log_bf16)
gated_delta_net_fused_op_2_kernel<<<blocks, threads, 0, stream>>>(ARGS(bfloat16));
else TORCH_CHECK(false, "gated_delta_net_fused_op_2: unsupported dtype");
#undef ARGS
cuda_check(cudaPeekAtLastError());
}