mirror of
https://github.com/ikawrakow/ik_llama.cpp.git
synced 2026-04-22 07:29:23 +00:00
gpt-oss: Seems to be working on CUDA
This commit is contained in:
@@ -37,6 +37,7 @@
|
||||
#include "ggml-cuda/unary.cuh"
|
||||
#include "ggml-cuda/upscale.cuh"
|
||||
#include "ggml-cuda/conv-transpose-1d.cuh"
|
||||
#include "ggml-cuda/add-id.cuh"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
@@ -2851,6 +2852,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_OP_ADD:
|
||||
ggml_cuda_op_add(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_ADD_ID:
|
||||
ggml_cuda_op_add_id(ctx, dst);
|
||||
break;
|
||||
case GGML_OP_MULTI_ADD:
|
||||
ggml_cuda_op_multi_add(ctx, dst);
|
||||
break;
|
||||
@@ -2877,6 +2881,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||
case GGML_UNARY_OP_SWIGLU:
|
||||
ggml_cuda_op_swiglu(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||
ggml_cuda_op_swiglu_oai(ctx, dst);
|
||||
break;
|
||||
case GGML_UNARY_OP_GELU_QUICK:
|
||||
ggml_cuda_op_gelu_quick(ctx, dst);
|
||||
break;
|
||||
@@ -3440,6 +3447,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_UNARY_OP_GELU:
|
||||
case GGML_UNARY_OP_SILU:
|
||||
case GGML_UNARY_OP_SWIGLU:
|
||||
case GGML_UNARY_OP_SWIGLU_OAI:
|
||||
case GGML_UNARY_OP_RELU:
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
case GGML_UNARY_OP_HARDSIGMOID:
|
||||
@@ -3629,6 +3637,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
|
||||
case GGML_OP_PERMUTE:
|
||||
case GGML_OP_TRANSPOSE:
|
||||
case GGML_OP_ADD:
|
||||
case GGML_OP_ADD_ID:
|
||||
case GGML_OP_MULTI_ADD:
|
||||
case GGML_OP_MUL:
|
||||
case GGML_OP_DIV:
|
||||
|
||||
@@ -108,6 +108,23 @@ static const char * cu_get_error_str(CUresult err) {
|
||||
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
|
||||
#endif
|
||||
|
||||
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
|
||||
const int id = ggml_cuda_get_device(); \
|
||||
if (!shared_memory_limit_raised[id]) { \
|
||||
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
|
||||
shared_memory_limit_raised[id] = true; \
|
||||
} \
|
||||
} while (0)
|
||||
#else
|
||||
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
|
||||
do { \
|
||||
GGML_UNUSED(nbytes); \
|
||||
} while (0)
|
||||
#endif // !(defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
|
||||
|
||||
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
|
||||
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
|
||||
#else
|
||||
|
||||
@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -747,6 +748,7 @@ void launch_fattn(
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -837,6 +839,7 @@ void launch_fattn(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *) sinks->data) : nullptr,
|
||||
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, softcap, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
@@ -1008,7 +1011,8 @@ void launch_fattn_mma(
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -1162,6 +1166,7 @@ void launch_fattn_mma(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
|
||||
@@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const half2 * const __restrict__ mask_h2,
|
||||
const float * const __restrict__ sinks_f,
|
||||
float2 * const __restrict__ dstk,
|
||||
float2 * const __restrict__ dstk_fixup,
|
||||
const float scale,
|
||||
@@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
|
||||
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
||||
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
|
||||
// so it's being done unconditionally for every thread.
|
||||
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
|
||||
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
|
||||
const float sink = sinks_f[jc % ncols2];
|
||||
|
||||
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
||||
KQ_max_scale[col] = expf(KQ_max_diff);
|
||||
KQ_max[col] = KQ_max_new;
|
||||
|
||||
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
||||
|
||||
const float KQ_max_add = expf(sink - KQ_max_new);
|
||||
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
||||
}
|
||||
|
||||
if (ntiles == 1) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
||||
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Write VKQ accumulators to shared memory in column-major format.
|
||||
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// Also for np > 1 the combination is done via these values in shared memory.
|
||||
@@ -823,6 +870,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -896,6 +944,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel * ncols2 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
@@ -906,12 +955,12 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
@@ -934,6 +983,7 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
@@ -943,10 +993,10 @@ static __global__ void flash_attn_mma_ext_f16(
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
|
||||
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
|
||||
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
|
||||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
|
||||
|
||||
@@ -818,6 +818,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
const half2 * const __restrict__ K_h2,
|
||||
const half2 * const __restrict__ V_h2,
|
||||
const half2 * const __restrict__ mask_h2,
|
||||
const float * const __restrict__ sinks_f,
|
||||
float2 * const __restrict__ dstk,
|
||||
float2 * const __restrict__ dstk_fixup,
|
||||
const float scale,
|
||||
@@ -989,6 +990,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
|
||||
// If attention sinks are used, potentially re-scale if KQ_max is small.
|
||||
// Also add the sink as a value to KQ_rowsum, this is done after synchonization of KQ_rowsum
|
||||
// so it's being done unconditionally for every thread.
|
||||
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
|
||||
float KQ_max_scale[cols_per_thread];
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
|
||||
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
|
||||
const float sink = sinks_f[jc % ncols2];
|
||||
|
||||
const float KQ_max_new = fmaxf(KQ_max[col], sink);
|
||||
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
|
||||
KQ_max_scale[col] = expf(KQ_max_diff);
|
||||
KQ_max[col] = KQ_max_new;
|
||||
|
||||
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
|
||||
|
||||
const float KQ_max_add = expf(sink - KQ_max_new);
|
||||
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
|
||||
}
|
||||
|
||||
if (ntiles == 1) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
|
||||
#pragma unroll
|
||||
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
|
||||
VKQ_C[i].x[l] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int col = 0; col < cols_per_thread; ++col) {
|
||||
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
|
||||
#pragma unroll
|
||||
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
|
||||
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Combine VKQ accumulator values if np > 1.
|
||||
// It's also faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
|
||||
// So also write VKQ accumulators to shared memory in column-major format if np == 1.
|
||||
@@ -1222,7 +1269,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
|
||||
}
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
|
||||
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(sinks_f);
|
||||
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
|
||||
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
|
||||
@@ -1239,6 +1286,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -1323,6 +1371,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
@@ -1335,12 +1384,12 @@ static __global__ void flash_attn_ext_f16(
|
||||
if (kb0_start == 0) {
|
||||
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
} else {
|
||||
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
}
|
||||
|
||||
@@ -1362,6 +1411,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
|
||||
@@ -1373,7 +1423,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
|
||||
constexpr bool needs_fixup = false;
|
||||
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
|
||||
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
|
||||
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
@@ -1535,7 +1585,8 @@ static void launch_fattn_new_mma(
|
||||
const ggml_tensor * K = dst->src[1];
|
||||
const ggml_tensor * V = dst->src[2];
|
||||
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * mask = dst->src[3];
|
||||
const ggml_tensor * sinks = dst->src[4];
|
||||
|
||||
ggml_tensor * KQV = dst;
|
||||
|
||||
@@ -1709,6 +1760,7 @@ static void launch_fattn_new_mma(
|
||||
K_data,
|
||||
V_data,
|
||||
mask ? ((const char *) mask->data) : nullptr,
|
||||
sinks ? ((const char *)sinks->data) : nullptr,
|
||||
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
|
||||
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
|
||||
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
|
||||
@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
|
||||
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
|
||||
@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
//
|
||||
|
||||
// TODO: attention sinks !!!
|
||||
|
||||
#include "common.cuh"
|
||||
#include "fattn-common.cuh"
|
||||
|
||||
@@ -22,6 +24,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const char * __restrict__ K,
|
||||
const char * __restrict__ V,
|
||||
const char * __restrict__ mask,
|
||||
const char * __restrict__ sinks,
|
||||
float * __restrict__ dst,
|
||||
float2 * __restrict__ dst_meta,
|
||||
const float scale,
|
||||
@@ -93,6 +96,7 @@ static __global__ void flash_attn_ext_f16(
|
||||
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
|
||||
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
|
||||
const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_K = nb11 / sizeof(half);
|
||||
|
||||
@@ -539,7 +539,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||
return;
|
||||
}
|
||||
|
||||
// As mentioned above, the new new MMA is slower than then the new MMA.
|
||||
// As mentioned above, the new-new MMA is slower then the new MMA.
|
||||
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
|
||||
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
|
||||
}
|
||||
|
||||
@@ -19,7 +19,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
|
||||
}
|
||||
|
||||
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
|
||||
static __global__ void soft_max_f32_nosinks(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
|
||||
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
@@ -124,7 +124,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
|
||||
static void soft_max_f32_cuda_nosinks(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
@@ -142,39 +142,40 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
|
||||
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
|
||||
switch (ncols_x) {
|
||||
case 32:
|
||||
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 64:
|
||||
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 128:
|
||||
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 256:
|
||||
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 512:
|
||||
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 1024:
|
||||
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 2048:
|
||||
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
case 4096:
|
||||
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
default:
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
break;
|
||||
}
|
||||
} else {
|
||||
const size_t shmem_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
soft_max_f32_nosinks<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
|
||||
}
|
||||
}
|
||||
|
||||
#if 0
|
||||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
@@ -205,13 +206,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
if (use_f16) {
|
||||
const half * src1_dd = (const half *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
} else {
|
||||
const float * src1_dd = (const float *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
@@ -241,10 +243,283 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
|
||||
if (use_f16) {
|
||||
const half * src1_dd = (const half *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
} else {
|
||||
const float * src1_dd = (const float *)src1_d;
|
||||
|
||||
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
|
||||
}
|
||||
}
|
||||
|
||||
struct soft_max_params {
|
||||
|
||||
int64_t nheads;
|
||||
uint32_t n_head_log2;
|
||||
int64_t ncols;
|
||||
int64_t nrows_x;
|
||||
int64_t nrows_y;
|
||||
int64_t ne00;
|
||||
int64_t ne01;
|
||||
int64_t ne02;
|
||||
int64_t ne03;
|
||||
int64_t nb11;
|
||||
int64_t nb12;
|
||||
int64_t nb13;
|
||||
|
||||
int64_t ne12;
|
||||
int64_t ne13;
|
||||
float scale;
|
||||
float max_bias;
|
||||
float m0;
|
||||
float m1;
|
||||
};
|
||||
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic push
|
||||
#pragma clang diagnostic ignored "-Wpass-failed"
|
||||
#endif // __clang__
|
||||
template <bool use_shared, int ncols_template, int block_size_template, typename T>
|
||||
static __global__ void soft_max_f32(
|
||||
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
|
||||
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
|
||||
const int64_t i03 = blockIdx.z;
|
||||
const int64_t i02 = blockIdx.y;
|
||||
const int64_t i01 = blockIdx.x;
|
||||
|
||||
//TODO: noncontigous inputs/outputs
|
||||
const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
|
||||
|
||||
const int64_t i11 = i01;
|
||||
const int64_t i12 = i02 % p.ne12;
|
||||
const int64_t i13 = i03 % p.ne13;
|
||||
|
||||
x += int64_t(rowx)*ncols;
|
||||
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
|
||||
dst += int64_t(rowx)*ncols;
|
||||
|
||||
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int lane_id = threadIdx.x % WARP_SIZE;
|
||||
|
||||
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
|
||||
|
||||
extern __shared__ float data_soft_max_f32[];
|
||||
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
|
||||
// shared memory buffer to cache values between iterations:
|
||||
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
|
||||
|
||||
float max_val = sinks ? sinks[i02] : -INFINITY;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
|
||||
|
||||
vals[col] = val;
|
||||
max_val = max(max_val, val);
|
||||
}
|
||||
|
||||
// find the max value in the block
|
||||
max_val = warp_reduce_max(max_val);
|
||||
if (block_size > WARP_SIZE) {
|
||||
if (warp_id == 0) {
|
||||
buf_iw[lane_id] = -INFINITY;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf_iw[warp_id] = max_val;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
max_val = buf_iw[lane_id];
|
||||
max_val = warp_reduce_max(max_val);
|
||||
}
|
||||
|
||||
float tmp = 0.0f; // partial sum
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
break;
|
||||
}
|
||||
|
||||
const float val = expf(vals[col] - max_val);
|
||||
tmp += val;
|
||||
vals[col] = val;
|
||||
}
|
||||
|
||||
// find the sum of exps in the block
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
if (block_size > WARP_SIZE) {
|
||||
__syncthreads();
|
||||
if (warp_id == 0) {
|
||||
buf_iw[lane_id] = 0.0f;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
if (lane_id == 0) {
|
||||
buf_iw[warp_id] = tmp;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
tmp = buf_iw[lane_id];
|
||||
tmp = warp_reduce_sum(tmp);
|
||||
}
|
||||
|
||||
if (sinks) {
|
||||
tmp += expf(sinks[i02] - max_val);
|
||||
}
|
||||
|
||||
const float inv_sum = 1.0f / tmp;
|
||||
|
||||
#pragma unroll
|
||||
for (int col0 = 0; col0 < ncols; col0 += block_size) {
|
||||
const int col = col0 + tid;
|
||||
|
||||
if (ncols_template == 0 && col >= ncols) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst[col] = vals[col] * inv_sum;
|
||||
}
|
||||
}
|
||||
#ifdef __clang__
|
||||
#pragma clang diagnostic pop
|
||||
#endif // __clang__
|
||||
|
||||
template<int... Ns, typename T>
|
||||
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
|
||||
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
|
||||
{
|
||||
const int id = ggml_cuda_get_device();
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
auto launch_kernel = [=](auto I) -> bool {
|
||||
constexpr int ncols = decltype(I)::value;
|
||||
constexpr int block = (ncols > 1024 ? 1024 : ncols);
|
||||
|
||||
if (p.ncols == ncols) {
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
|
||||
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
|
||||
(x, mask, sinks, dst, p);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
// unary fold over launch_kernel
|
||||
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
|
||||
return;
|
||||
}
|
||||
|
||||
//default case
|
||||
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
|
||||
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
|
||||
int nth = WARP_SIZE;
|
||||
const int64_t ncols_x = params.ncols;
|
||||
|
||||
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
|
||||
const dim3 block_dims(nth, 1, 1);
|
||||
const dim3 block_nums(params.ne01, params.ne02, params.ne03);
|
||||
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
|
||||
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
|
||||
|
||||
|
||||
const int id = ggml_cuda_get_device();
|
||||
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
|
||||
|
||||
|
||||
if (nbytes_shared <= smpbo) {
|
||||
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
|
||||
} else {
|
||||
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
|
||||
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
|
||||
}
|
||||
}
|
||||
|
||||
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
const ggml_tensor * src2 = dst->src[2];
|
||||
|
||||
const float * src0_d = (const float *) src0->data;
|
||||
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
|
||||
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
|
||||
float * dst_d = (float *) dst->data;
|
||||
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
|
||||
|
||||
const int64_t nrows_x = ggml_nrows(src0);
|
||||
const int64_t nrows_y = src0->ne[1];
|
||||
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
|
||||
float scale = 1.0f;
|
||||
float max_bias = 0.0f;
|
||||
|
||||
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
|
||||
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
|
||||
|
||||
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
|
||||
|
||||
const int64_t nb11 = src1 ? src1->nb[1] : 1;
|
||||
const int64_t nb12 = src1 ? src1->nb[2] : 1;
|
||||
const int64_t nb13 = src1 ? src1->nb[3] : 1;
|
||||
|
||||
const int64_t ne12 = src1 ? src1->ne[2] : 1;
|
||||
const int64_t ne13 = src1 ? src1->ne[3] : 1;
|
||||
|
||||
const uint32_t n_head = src0->ne[2];
|
||||
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
|
||||
|
||||
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
|
||||
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
|
||||
|
||||
|
||||
soft_max_params params = {};
|
||||
params.nheads = src0->ne[2];
|
||||
params.n_head_log2 = n_head_log2;
|
||||
params.ncols = ne00;
|
||||
params.nrows_x = nrows_x;
|
||||
params.nrows_y = nrows_y;
|
||||
params.ne00 = src0->ne[0];
|
||||
params.ne01 = src0->ne[1];
|
||||
params.ne02 = src0->ne[2];
|
||||
params.ne03 = src0->ne[3];
|
||||
params.nb11 = nb11;
|
||||
params.nb12 = nb12;
|
||||
params.nb13 = nb13;
|
||||
params.ne12 = ne12;
|
||||
params.ne13 = ne13;
|
||||
params.scale = scale;
|
||||
params.max_bias = max_bias;
|
||||
params.m0 = m0;
|
||||
params.m1 = m1;
|
||||
|
||||
if (use_f16) {
|
||||
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
|
||||
} else {
|
||||
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -470,3 +470,79 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
|
||||
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
|
||||
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
|
||||
|
||||
if (i >= k) {
|
||||
return;
|
||||
}
|
||||
|
||||
// perform base op and multiply with gate (either offset in same tensor or a separate one)
|
||||
const int64_t j0 = (i / n) * o0 + (i % n);
|
||||
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
||||
|
||||
float xi = x[j0];
|
||||
float gi = g[j1];
|
||||
xi = fminf(xi, limit);
|
||||
gi = fmaxf(fminf(gi, limit), -limit);
|
||||
|
||||
float out_glu = xi / (1.0f + expf(-xi * alpha));
|
||||
out_glu = out_glu * (1.0f + gi);
|
||||
|
||||
dst[i] = out_glu;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
|
||||
const int64_t num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
|
||||
swiglu_oai_kernel<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
|
||||
}
|
||||
|
||||
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||
const ggml_tensor * src0 = dst->src[0];
|
||||
const ggml_tensor * src1 = dst->src[1];
|
||||
void * src0_d = src0->data;
|
||||
void * src1_d = src1 ? src1->data : src0->data;
|
||||
const int64_t src0_o = src0->nb[1];
|
||||
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||
void * dst_d = dst->data;
|
||||
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||
cudaStream_t stream = ctx.stream();
|
||||
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
|
||||
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||
|
||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT( dst->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(src0->type == dst->type);
|
||||
GGML_ASSERT(dst->ne[0] == nc);
|
||||
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
|
||||
|
||||
if (src1) {
|
||||
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
|
||||
GGML_ASSERT(src1->ne[0] == nc);
|
||||
GGML_ASSERT(src0->type == src1->type);
|
||||
}
|
||||
|
||||
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
||||
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
|
||||
const float * op_params = (const float *)dst->op_params;
|
||||
const float alpha = op_params[2];
|
||||
const float limit = op_params[3];
|
||||
|
||||
float * src0_p = (float *) src0_d;
|
||||
float * src1_p = (float *) src1_d;
|
||||
|
||||
if (!src1) {
|
||||
src0_p += swapped ? nc : 0;
|
||||
src1_p += swapped ? 0 : nc;
|
||||
}
|
||||
|
||||
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc,
|
||||
src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
|
||||
}
|
||||
|
||||
|
||||
@@ -47,3 +47,5 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
|
||||
int64_t nelements, const float * x, const float * y, float * z);
|
||||
|
||||
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||
|
||||
Reference in New Issue
Block a user