Enable CUDA graphs for MoE models + GPT-OSS support (#689)

* gpt-oss: common

* gpt-oss: attention sinks, swiglu_oai

* gpt-oss: WIP llama

The model loads and runs (CPU only), but PPL is much too high
(~1500 for 1st batch vs ~200 in mainline).
Is it because of SWA, because of vocab, or did I introduce a bug somewhere?

* gpt-oss: CPU seems to be working

It was the SWA that was missing in the previous commit.

There are issues with EOG tokens, so EOG handling still needs to be added.

* CUDA: ADD_ID

Just a copy from mainline

* gpt-oss: Seems to be working on CUDA

* gpt-oss: add sinks to the attn-vec kernels

* CUDA: add head size of 64 to new mma

I haven't turned it on yet, but I observe slightly better PP and slightly
worse TG performance with it.

* gpt-oss: add ability to use -fmoe (only CUDA for now)

* Move row sums to the right place

* Add sinks to iqk flash attention

* gpt_oss: Implement -fmoe on the CPU

* Simdify swiglu_oai

Turning it off for now as performance becomes more variable;
perhaps I'm running into thermal throttling more often
because the CPU is working harder.

* llama: factor out model loader

* Builds successfully

* It runs, but mmap does not work

* Fix llama_mmap so mmap works

* Minor

* Fix CUDA after latest changes

* Attempt to use CUDA graphs with MoE models - not working

* CUDA graphs WIP - still not working

* CUDA graphs - seems to be working

Likely not all MLA variants are working.
I no longer remember why I added the q8_0 cpy that
transposes the tensor, but if it is really needed, it is now
missing. q6_0 is also missing.

* Make q8_0 cache work for DeepSeek models with CUDA graphs

* cuda: cpy for q6_0

* Fix llama_mmap on non-Linux platforms

* Adding forgotten file

* Iterating on Windows build failures

* cuda: re-add q8_0 -> q8_0 transpose

so mla = 2 can be used with CUDA graphs and q8_0 cache.

* Disable graphs without -fmoe

* Minor

* Turn graphs on by default

---------

Co-authored-by: Iwan Kawrakow <iwan.kawrakow@gmail.com>
Author: Kawrakow
Date: 2025-08-15 09:18:07 +03:00 (committed by GitHub)
Parent: c00335684c
Commit: 633e0617b0
56 changed files with 8720 additions and 5115 deletions


@@ -132,7 +132,7 @@ set (GGML_CUDA_MIN_BATCH_OFFLOAD "32" CACHE STRING
option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF)
option(GGML_CUDA_NO_VMM "ggml: do not try to use CUDA VMM" OFF)
option(GGML_CUDA_FA_ALL_QUANTS "ggml: compile all quants for FlashAttention" OFF)
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" OFF)
option(GGML_CUDA_USE_GRAPHS "ggml: use CUDA graphs (llama.cpp only)" ON)
option(GGML_IQK_FLASH_ATTENTION "ggml: enable the IQK FlashAttention CPU kernels" ON)
option(GGML_IQK_FA_ALL_QUANTS "ggml: compile all quants for IQK FlashAttention" OFF)


@@ -325,6 +325,16 @@
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
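// Locals for ternary ops (three sources plus dst), used by three-input ops
// such as GGML_OP_ADD_ID: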
#define GGML_TENSOR_TERNARY_OP_LOCALS \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \
GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \
GGML_TENSOR_LOCALS(int64_t, ne2, src2, ne) \
GGML_TENSOR_LOCALS(size_t, nb2, src2, nb) \
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
#define GGML_TENSOR_BINARY_OP_LOCALS01 \
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \
@@ -571,6 +581,7 @@ extern "C" {
GGML_OP_DUP,
GGML_OP_ADD,
GGML_OP_ADD_ID,
GGML_OP_ADD1,
GGML_OP_ACC,
GGML_OP_SUB,
@@ -674,6 +685,7 @@ extern "C" {
GGML_UNARY_OP_HARDSWISH,
GGML_UNARY_OP_HARDSIGMOID,
GGML_UNARY_OP_SWIGLU,
GGML_UNARY_OP_SWIGLU_OAI,
GGML_UNARY_OP_COUNT,
};
@@ -1028,6 +1040,13 @@ extern "C" {
struct ggml_tensor * b,
enum ggml_type type);
// dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]]
GGML_API struct ggml_tensor * ggml_add_id(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * ids);
GGML_API struct ggml_tensor * ggml_add1(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1268,6 +1287,13 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
GGML_API struct ggml_tensor * ggml_swiglu_oai(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float alpha,
float limit);
// a - x
// b - dy
GGML_API struct ggml_tensor * ggml_silu_back(
@@ -1370,6 +1396,16 @@ extern "C" {
struct ggml_tensor * ids,
enum ggml_unary_op op);
GGML_API struct ggml_tensor * ggml_moe_up_gate_ext(
struct ggml_context * ctx,
struct ggml_tensor * a_up,
struct ggml_tensor * a_gate,
struct ggml_tensor * b,
struct ggml_tensor * ids,
struct ggml_tensor * a_up_b,
struct ggml_tensor * a_gate_b,
enum ggml_unary_op op);
// A: m columns, n rows,
// B: p columns, n rows,
// result is m columns, p rows
@@ -1662,6 +1698,11 @@ extern "C" {
float scale,
float max_bias);
GGML_API void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
GGML_API struct ggml_tensor * ggml_soft_max_back(
struct ggml_context * ctx,
struct ggml_tensor * a,
@@ -1998,6 +2039,10 @@ extern "C" {
struct ggml_tensor * a,
enum ggml_prec prec);
GGML_API void ggml_flash_attn_ext_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks);
// TODO: needs to be adapted to ggml_flash_attn_ext
GGML_API struct ggml_tensor * ggml_flash_attn_back(
struct ggml_context * ctx,
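
For reference, ggml_moe_up_gate_ext declared above evaluates, per token, the up and gate projections of each selected expert in a single pass and applies the unary op, i.e. out_e = act(gate_e(x)) * up_e(x), with the optional per-expert biases a_up_b/a_gate_b added before the activation. A minimal scalar sketch of the semantics, assuming SILU as the activation and illustrative names throughout (not the actual implementation):

#include <cmath>
#include <vector>

static float silu(float x) { return x / (1.0f + std::exp(-x)); }

// x: [n_embd]; up_w, gate_w: [n_expert][n_ff][n_embd]; ids: selected experts
std::vector<std::vector<float>> moe_up_gate_ref(
        const std::vector<float> & x,
        const std::vector<std::vector<std::vector<float>>> & up_w,
        const std::vector<std::vector<std::vector<float>>> & gate_w,
        const std::vector<int> & ids) {
    std::vector<std::vector<float>> out(ids.size());
    for (size_t k = 0; k < ids.size(); ++k) {
        const int e = ids[k];
        out[k].resize(up_w[e].size());
        for (size_t i = 0; i < up_w[e].size(); ++i) {
            float up = 0.0f, gate = 0.0f;
            for (size_t j = 0; j < x.size(); ++j) {
                up   += up_w[e][i][j]   * x[j];
                gate += gate_w[e][i][j] * x[j];
            }
            out[k][i] = silu(gate) * up;  // one fused pass over both matrices
        }
    }
    return out;
}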


@@ -43,6 +43,7 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
case GGML_OP_DIAG_MASK_ZERO:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_SUB:
case GGML_OP_MUL:


@@ -37,6 +37,8 @@
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/add-id.cuh"
#include "ggml-cuda/graph.cuh"
#include <algorithm>
#include <array>
@@ -49,6 +51,7 @@
#include <map>
#include <memory>
#include <mutex>
#include <condition_variable>
#include <stdint.h>
#include <stdio.h>
#include <stdarg.h>
@@ -77,6 +80,7 @@ GGML_API void ggml_backend_cuda_log_set_callback(ggml_log_callback log_callback,
#define GGML_CUDA_LOG_INFO(...) ggml_cuda_log(GGML_LOG_LEVEL_INFO, __VA_ARGS__)
#define GGML_CUDA_LOG_WARN(...) ggml_cuda_log(GGML_LOG_LEVEL_WARN, __VA_ARGS__)
#define GGML_CUDA_LOG_ERROR(...) ggml_cuda_log(GGML_LOG_LEVEL_ERROR, __VA_ARGS__)
#define GGML_CUDA_LOG_DEBUG(...) ggml_cuda_log(GGML_LOG_LEVEL_DEBUG, __VA_ARGS__)
GGML_ATTRIBUTE_FORMAT(2, 3)
static void ggml_cuda_log(enum ggml_log_level level, const char * format, ...) {
@@ -444,6 +448,35 @@ std::unique_ptr<ggml_cuda_pool> ggml_backend_cuda_context::new_pool_for_device(i
return std::unique_ptr<ggml_cuda_pool>(new ggml_cuda_pool_leg(device));
}
static std::mutex ggml_cuda_lock;
static std::condition_variable ggml_cuda_lock_cv;
static std::atomic<int> ggml_cuda_lock_counter;
ggml_backend_cuda_context::ggml_backend_cuda_context(int device) :
device(device), name(GGML_CUDA_NAME + std::to_string(device)) {
}
ggml_backend_cuda_context::~ggml_backend_cuda_context() {
std::unique_lock<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_cv.wait(lock, []{ return ggml_cuda_lock_counter.load(std::memory_order_relaxed) == 0; });
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
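
The mutex/condition-variable pair above implements a simple drain: graph capture increments ggml_cuda_lock_counter before cudaStreamBeginCapture and decrements it (with a notify) after cudaStreamEndCapture, so the destructor blocks until no capture is in flight before tearing down streams and handles. The generic pattern, as a standalone sketch:

#include <atomic>
#include <condition_variable>
#include <mutex>

std::mutex m;
std::condition_variable cv;
std::atomic<int> in_flight{0};

void begin_work() {                      // e.g. before starting graph capture
    std::lock_guard<std::mutex> lock(m);
    in_flight.fetch_add(1, std::memory_order_relaxed);
}

void end_work() {                        // e.g. after capture has finished
    std::lock_guard<std::mutex> lock(m);
    if (in_flight.fetch_sub(1, std::memory_order_relaxed) == 1) {
        cv.notify_all();                 // last worker wakes any waiter
    }
}

void wait_until_idle() {                 // what the destructor does first
    std::unique_lock<std::mutex> lock(m);
    cv.wait(lock, [] { return in_flight.load(std::memory_order_relaxed) == 0; });
}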
// cuda buffer
struct ggml_backend_cuda_buffer_context {
@@ -2220,6 +2253,24 @@ static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_origin
}
}
//static __global__ void k_quick_add(uint32_t n, uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
//
// for (uint32_t j = threadIdx.x; j < n; j += blockDim.x) {
// dst[j] = src1[j] + src2[j % n_per_row];
// }
//}
static __global__ void k_quick_add(uint32_t n_per_row, const float * src1, const float * src2, float * dst) {
uint32_t row = blockIdx.x;
const float * src1_row = src1 + row*n_per_row;
float * dst_row = dst + row*n_per_row;
for (uint32_t j = threadIdx.x; j < n_per_row; j += blockDim.x) {
dst_row[j] = src1_row[j] + src2[j];
}
}
static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n_as, int64_t n_ids,
const ggml_tensor * ids, std::vector<int>& moe_counts, std::vector<int>& cum_moe_counts,
ggml_cuda_pool_alloc<mmid_row_mapping>& dev_row_mapping) {
@@ -2270,7 +2321,7 @@ static inline bool prepare_row_mappigs(ggml_backend_cuda_context& ctx, int64_t n
return is_ser;
}
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
static bool ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];
@@ -2319,7 +2370,25 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
return;
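// If the next node is another MUL_MAT_ID consuming the same activations
// (same src1, a same-shaped src0 of the same type, all on this device),
// compute it here as well so the already-quantized src1 is reused, and
// return true so the caller skips that node.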
if (next && next->op == GGML_OP_MUL_MAT_ID && next->src[0]->type == src0->type && src1 == next->src[1] &&
ggml_are_same_shape(src0, next->src[0]) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
ggml_backend_buffer_is_cuda(next->buffer) &&
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer)) {
ggml_backend_cuda_buffer_context * next_src0_ctx = (ggml_backend_cuda_buffer_context *) next->src[0]->buffer->context;
ggml_backend_cuda_buffer_context * next_dst_ctx = (ggml_backend_cuda_buffer_context *) next->buffer->context;
if (next_src0_ctx->device == device_id &&
next_dst_ctx->device == device_id) {
local_dst.data = next->data;
ggml_cuda_op_mul_mat_vec_q_id(ctx, next->src[0], &local_src1, ids, &local_dst,
(const char *)next->src[0]->data, nullptr, src1_quantized.get(), (float *)next->data,
0, src0->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
return true;
}
}
return false;
}
}
@@ -2356,7 +2425,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
if (ne12 == 1) {
if (false && ne12 == 1) {
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
@@ -2442,6 +2511,7 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
}
}
}
return false;
}
static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * next) {
@@ -2470,6 +2540,8 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
src0_2_ctx->device == device_id &&
src1_ctx->device == device_id &&
dst_ctx->device == device_id) {
//printf("%s(%s, %s): %ld x %ld x %ld, %ld x %ld x %ld, %ld x %ld x %ld\n", __func__, src0_1->name, src0_2->name,
// src0->ne[0], src0->ne[1], src0->ne[2], src1->ne[0], src1->ne[1], src1->ne[2], ids->ne[0], ids->ne[1], ids->ne[2]);
// Fast TG path
const int64_t n_ids = ids->ne[0];
auto stream = ctx.stream(device_id, 0);
@@ -2505,12 +2577,26 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
0, src0_1->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
if (dst->src[4]) {
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[4]->data,
(const int32_t *)ids->data, (float *)local_dst.data,
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
local_dst.nb[1], local_dst.nb[2], dst->src[4]->nb[1], ids->nb[2], stream);
}
local_dst.data = dst_gate_contiguous.get();
ggml_cuda_op_mul_mat_vec_q_id(ctx, src0_2, &local_src1, ids, &local_dst,
(const char *)src0_2->data, (const float *)src1->data, src1_quantized.get(), (float *)dst_gate_contiguous.get(),
0, src0_2->ne[1], 1, src1_padded_col_size, stream);
CUDA_CHECK(cudaGetLastError());
if (dst->src[5]) {
ggml_cuda_add_id((const float *)local_dst.data, (const float *)dst->src[5]->data,
(const int32_t *)ids->data, (float *)local_dst.data,
local_dst.ne[0], local_dst.ne[2], local_dst.ne[1], local_dst.ne[0], local_dst.ne[2],
local_dst.nb[1], local_dst.nb[2], dst->src[5]->nb[1], ids->nb[2], stream);
}
if (next && next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type) &&
ggml_backend_buffer_is_cuda(next->src[0]->buffer) &&
!ggml_backend_buffer_is_cuda_split(next->src[0]->buffer) &&
@@ -2518,8 +2604,15 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
ggml_backend_buffer_is_cuda(next->buffer) &&
((ggml_backend_cuda_buffer_context *)next->buffer->context)->device == device_id) {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], dst->ne[0]*n_ids,
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, unary_op, dst->ne[0]*n_ids,
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
}
CUDA_CHECK(cudaGetLastError());
const int64_t dst_padded_col_size = GGML_PAD(dst->ne[0], MATRIX_ROW_PADDING);
@@ -2555,8 +2648,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
return true;
} else {
CUDA_CHECK(cudaMemsetAsync(dst->data, 0, ggml_nbytes(dst), stream));
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(dst),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst->data, dst->ne[0]*n_ids, dst->ne[0], dst->ne[0], dst->ne[0], 1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, unary_op, ggml_nelements(dst),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst->data);
}
CUDA_CHECK(cudaGetLastError());
return false;
}
@@ -2624,7 +2723,7 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
final_src.nb[3] = final_src.nb[2];
}
if (ne12 == 1) {
if (false && ne12 == 1) {
ggml_cuda_pool_alloc<char> dst_up_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
ggml_cuda_pool_alloc<char> dst_gate_contiguous(ctx.pool(), sizeof(float)*dst_row.ne[0]);
if (fuse_down) {
@@ -2761,6 +2860,14 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
CUDA_CHECK(cudaGetLastError());
if (dst->src[4]) {
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
dim3 grid_dims(num_src1_rows);
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
(const float *)((const char *)dst->src[4]->data + i02*dst->src[4]->nb[1]), (float *)dst_row.data);
CUDA_CHECK(cudaGetLastError());
}
dst_row.data = dst_gate_contiguous.get();
if (use_quantized_src1) {
ggml_cuda_op_mul_mat_q(ctx, &src0_2_row, &src1_row, &dst_row, (const char *)src0_2_row.data, nullptr, src1_quantized.get(), (float *)dst_row.data,
@@ -2770,8 +2877,24 @@ static bool ggml_cuda_up_gate_unary(ggml_backend_cuda_context & ctx, ggml_tensor
}
CUDA_CHECK(cudaGetLastError());
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(), (float *)dst_gate_contiguous.get());
if (dst->src[5]) {
dim3 block_dims(std::min(uint32_t(dst_row.ne[0]), 768u));
dim3 grid_dims(num_src1_rows);
k_quick_add<<<grid_dims, block_dims, 0, stream>>>(dst_row.ne[0], (const float *)dst_row.data,
(const float *)((const char *)dst->src[5]->data + i02*dst->src[5]->nb[1]), (float *)dst_row.data);
CUDA_CHECK(cudaGetLastError());
}
auto unary_op = (ggml_unary_op)dst->op_params[0];
if (unary_op == GGML_UNARY_OP_SWIGLU_OAI) {
ggml_swiglu_oai_cuda_f32((const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get(), ggml_nelements(&dst_row), dst_row.ne[0], dst_row.ne[0], dst_row.ne[0],
1.702f, 7.0f, stream);
} else {
ggml_fused_mul_unary(ctx, (ggml_unary_op)dst->op_params[0], ggml_nelements(&dst_row),
(const float *)dst_gate_contiguous.get(), (const float *)dst_up_contiguous.get(),
(float *)dst_gate_contiguous.get());
}
CUDA_CHECK(cudaGetLastError());
if (fuse_down) {
@@ -2851,6 +2974,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_ADD:
ggml_cuda_op_add(ctx, dst);
break;
case GGML_OP_ADD_ID:
ggml_cuda_op_add_id(ctx, dst);
break;
case GGML_OP_MULTI_ADD:
ggml_cuda_op_multi_add(ctx, dst);
break;
@@ -2877,6 +3003,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_UNARY_OP_SWIGLU:
ggml_cuda_op_swiglu(ctx, dst);
break;
case GGML_UNARY_OP_SWIGLU_OAI:
ggml_cuda_op_swiglu_oai(ctx, dst);
break;
case GGML_UNARY_OP_GELU_QUICK:
ggml_cuda_op_gelu_quick(ctx, dst);
break;
@@ -2938,7 +3067,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
}
break;
case GGML_OP_MUL_MAT_ID:
ggml_cuda_mul_mat_id(ctx, dst);
skip_next = ggml_cuda_mul_mat_id(ctx, dst, next);
break;
case GGML_OP_MOE_FUSED_UP_GATE:
skip_next = ggml_cuda_up_gate_unary(ctx, dst, next);
@@ -3119,6 +3248,105 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
#ifdef USE_CUDA_GRAPH
static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool use_cuda_graph) {
// Loop over nodes in GGML graph to obtain info needed for CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.clear();
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj";
const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased";
const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased";
const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased";
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
if (node->src[0] && node->src[0]->buffer && ggml_backend_buft_is_cuda_split(node->src[0]->buffer->buft)) {
use_cuda_graph = false; // Split buffers are not supported by CUDA graph capture
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to split buffer\n", __func__);
#endif
}
if (node->op == GGML_OP_MUL_MAT_ID && (node->ne[2] != 1 || node->src[2]->ne[0] != 1)) {
use_cuda_graph = false; // This node type is not supported by CUDA graph capture
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s(%s): disabling CUDA graphs due to unsupported node type %ld %ld\n",
__func__, node->src[0]->name, node->ne[2], node->src[2]->ne[0]);
#endif
}
if (node->op == GGML_OP_MOE_FUSED_UP_GATE) {
auto src0_1 = node->src[0];
auto src0_2 = node->src[1];
auto src1 = node->src[2];
if (src1->ne[1] != 1 || src1->ne[2] != 1 || src1->ne[3] != 1 || src1->type != GGML_TYPE_F32 ||
!ggml_is_quantized(src0_1->type) || !ggml_is_quantized(src0_2->type)) {
use_cuda_graph = false;
} else {
if (i < cgraph->n_nodes-1) {
auto next = cgraph->nodes[i+1];
if (next->op == GGML_OP_MUL_MAT_ID && ggml_is_quantized(next->src[0]->type)) {
++i;
}
}
}
}
if (node->op == GGML_OP_ADD &&
node->src[1] && node->src[1]->ne[1] > 1 &&
(node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) &&
(node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) &&
strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 &&
strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) {
// disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation
// by means of matching node names. See
// https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and
// https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773,
// Generally, changes in batch size or context size can cause changes to the grid size of some kernels.
use_cuda_graph = false;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]);
#endif
}
if (node->op == GGML_OP_CPY) {
// Store the pointers which are updated for each token, such that these can be sent
// to the device and accessed using indirection from CUDA graph
cuda_ctx->cuda_graph->cpy_dest_ptrs.push_back((char *) node->src[1]->data);
// store a pointer to each copy op CUDA kernel to identify it later
void * ptr = ggml_cuda_cpy_fn(node->src[0], node->src[1]);
if (!ptr) {
use_cuda_graph = false;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported copy op\n", __func__);
#endif
}
}
if (!use_cuda_graph) {
break;
}
}
if (use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = true;
// copy pointers to GPU so they can be accessed via indirection within CUDA graph
ggml_cuda_cpy_dest_ptrs_copy(cuda_ctx->cuda_graph.get(), cuda_ctx->cuda_graph->cpy_dest_ptrs.data(), cuda_ctx->cuda_graph->cpy_dest_ptrs.size(), cuda_ctx->stream());
}
return use_cuda_graph;
}
static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
graph_node_properties->node_address = node->data;
graph_node_properties->node_op = node->op;
@@ -3129,6 +3357,7 @@ static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_p
for (int i = 0; i < GGML_MAX_SRC; i++) {
graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
}
memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
}
static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
@@ -3160,9 +3389,246 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
return false;
}
}
if (node->op == GGML_OP_SCALE &&
memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false;
}
return true;
}
static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->instance == nullptr) {
cuda_graph_update_required = true;
}
// Check if the graph size has changed
if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
cuda_graph_update_required = true;
cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
}
// Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) {
bool has_matching_properties = true;
if (!cuda_graph_update_required) {
has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
if (!has_matching_properties) {
cuda_graph_update_required = true;
}
set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
}
return cuda_graph_update_required;
}
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
#else
cudaGraphNode_t errorNode;
cudaGraphExecUpdateResult result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
#endif
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clear error and re-instantiate
(void)cudaGetLastError();
CUDA_CHECK(cudaGraphExecDestroy(cuda_ctx->cuda_graph->instance));
cuda_ctx->cuda_graph->instance = nullptr;
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
} else {
GGML_ASSERT(stat == cudaSuccess);
}
}
#endif
static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu
// TODO
const bool integrated = false; //ggml_cuda_info().devices[cuda_ctx->device].integrated;
while (!graph_evaluated_or_captured) {
// Only perform the graph execution if CUDA graphs are not enabled, or we are capturing the graph.
// With the use of CUDA graphs, the execution will be performed by the graph launch.
if (!use_cuda_graph || cuda_graph_update_required) {
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
ggml_tensor * next = i < cgraph->n_nodes-1 ? cgraph->nodes[i+1] : nullptr;
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
continue;
}
#if 0
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) {
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL }, {})) {
ggml_cuda_op_rms_norm_fused(*cuda_ctx, node, cgraph->nodes[i+1]);
i++;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_SCALE, GGML_OP_UNARY, GGML_OP_SCALE }, { GGML_UNARY_OP_TANH })) {
i += 2;
ggml_cuda_op_softcap(*cuda_ctx, cgraph->nodes[i], node);
continue;
}
}
#endif
#ifndef NDEBUG
assert(node->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device));
for (int j = 0; j < GGML_MAX_SRC; j++) {
if (node->src[j] != nullptr) {
assert(node->src[j]->buffer);
//assert(node->src[j]->buffer->buft == ggml_backend_cuda_buffer_type(cuda_ctx->device) ||
// ggml_backend_buft_is_cuda_split(node->src[j]->buffer->buft) || (integrated && ggml_backend_buft_is_cuda_host(node->src[j]->buffer->buft)));
}
}
#else
GGML_UNUSED(integrated);
#endif // NDEBUG
bool skip_next = false;
bool ok = ggml_cuda_compute_forward(*cuda_ctx, node, next, skip_next);
if (!ok) {
GGML_CUDA_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
}
GGML_ASSERT(ok);
if (skip_next) ++i;
}
}
#ifdef USE_CUDA_GRAPH
if (use_cuda_graph && cuda_graph_update_required) { // End CUDA graph capture
if (cuda_ctx->cuda_graph->graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(cuda_ctx->cuda_graph->graph));
cuda_ctx->cuda_graph->graph = nullptr;
}
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cuda_ctx->cuda_graph->graph));
graph_evaluated_or_captured = true; // CUDA graph has been captured
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
if (ggml_cuda_lock_counter.fetch_sub(1, std::memory_order_relaxed) == 1) {
ggml_cuda_lock_cv.notify_all();
}
} else {
graph_evaluated_or_captured = true; // ggml graph has been directly evaluated
}
}
if (use_cuda_graph) {
if (cuda_ctx->cuda_graph->instance == nullptr) { // Create executable graph from captured graph.
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
}
if (cuda_graph_update_required) { // Update graph executable
update_cuda_graph_executable(cuda_ctx);
}
// Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
#else
graph_evaluated_or_captured = true;
#endif // USE_CUDA_GRAPH
}
}
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
#ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
}
bool use_cuda_graph = true;
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
}
}
// Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
// or previous graph capture failure.
// Also disable for multi-gpu for now. TO DO investigate
if (disable_cuda_graphs_due_to_env
|| cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
|| cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
use_cuda_graph = false;
}
if (use_cuda_graph) {
cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
use_cuda_graph = check_node_graph_compatibility_and_refresh_copy_ops(cuda_ctx, cgraph, use_cuda_graph);
// Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
cuda_ctx->cuda_graph->number_consecutive_updates++;
} else {
cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
GGML_CUDA_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
}
}
if (use_cuda_graph && cuda_graph_update_required) {
// Start CUDA graph capture
{
std::lock_guard<std::mutex> lock(ggml_cuda_lock);
ggml_cuda_lock_counter.fetch_add(1, std::memory_order_relaxed);
}
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
}
if (!use_cuda_graph) {
cuda_ctx->cuda_graph->use_cpy_indirection = false;
}
#else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#endif // USE_CUDA_GRAPH
bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
return GGML_STATUS_SUCCESS;
}
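
For context, the capture/replay flow implemented above follows the standard CUDA graph pattern. A minimal standalone sketch (error handling and the graph-update path omitted; enqueue_all_kernels() stands in for the per-node launches of the eager path and is an assumption, not a real function here):

#include <cuda_runtime.h>

void enqueue_all_kernels(cudaStream_t stream);  // hypothetical eager path

void run_with_graph(cudaStream_t stream) {
    static cudaGraph_t     graph    = nullptr;
    static cudaGraphExec_t instance = nullptr;
    if (instance == nullptr) {
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
        enqueue_all_kernels(stream);   // kernels are recorded, not executed
        cudaStreamEndCapture(stream, &graph);
        cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
    }
    cudaGraphLaunch(instance, stream); // replay the whole graph in one launch
}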
/*
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
@@ -3431,6 +3897,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
return GGML_STATUS_SUCCESS;
}
*/
GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
@@ -3440,6 +3907,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_SWIGLU_OAI:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_SIGMOID:
case GGML_UNARY_OP_HARDSIGMOID:
@@ -3629,6 +4097,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_MULTI_ADD:
case GGML_OP_MUL:
case GGML_OP_DIV:
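
For reference, the SWIGLU_OAI activation enabled above is invoked earlier with alpha = 1.702 and limit = 7.0. To my understanding of the gpt-oss formulation, it clamps the gate from above and the up value symmetrically, then multiplies a scaled-sigmoid GLU term by (up + 1); a scalar sketch under that assumption, not the exact kernel:

#include <algorithm>
#include <cmath>

// Scalar sketch of swiglu_oai (assumed formulation; not the exact kernel).
float swiglu_oai_ref(float gate, float up, float alpha = 1.702f, float limit = 7.0f) {
    gate = std::min(gate, limit);                  // clamp gate from above
    up   = std::max(std::min(up, limit), -limit);  // clamp up to [-limit, limit]
    const float glu = gate / (1.0f + std::exp(-alpha * gate));  // gate * sigmoid(alpha*gate)
    return glu * (up + 1.0f);
}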


@@ -0,0 +1,72 @@
#include "add-id.cuh"
static __global__ void add_id_kernel(
const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne0, int64_t ne1,
size_t nb01, size_t nb02,
size_t nb11,
size_t nb21
) {
const int64_t i1 = blockIdx.x;
const int64_t i2 = blockIdx.y;
const int i11 = *(int32_t *) ((char *) src2 + i1*sizeof(int32_t) + i2*nb21);
const size_t nb1 = ne0 * sizeof(float);
const size_t nb2 = ne1 * nb1;
float * dst_row = (float *)((char *)dst + i1*nb1 + i2*nb2);
const float * src0_row = (const float *)((char *)src0 + i1*nb01 + i2*nb02);
const float * src1_row = (const float *)((char *)src1 + i11*nb11);
for (int64_t i0 = threadIdx.x; i0 < ne0; i0 += blockDim.x) {
dst_row[i0] = src0_row[i0] + src1_row[i0];
}
}
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(nb00 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
GGML_ASSERT(nb20 == sizeof(int32_t));
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
const int32_t * src2_d = (const int32_t *)src2->data;
float * dst_d = (float *)dst->data;
int threads = std::min((int)ne00, 768); // cols
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
add_id_kernel<<<blocks, threads, 0, ctx.stream()>>>(
src0_d, src1_d, src2_d, dst_d,
ne0, ne1,
nb01, nb02,
nb11,
nb21
);
}
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne00, int64_t ne01, int64_t ne02,
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream) {
int threads = std::min((int)ne00, 768); // cols
dim3 blocks(ne01, ne02); // n_experts_used, n_tokens
add_id_kernel<<<blocks, threads, 0, stream>>>(
src0, src1, src2, dst,
ne0, ne1,
nb01, nb02,
nb11,
nb21
);
}
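
A plain CPU reference of the same op is useful for checking the kernel; it mirrors the header comment dst[i0, i1, i2] = a[i0, i1, i2] + b[i0, ids[i1, i2]] (sketch; contiguous f32 tensors assumed):

#include <cstdint>

// CPU reference for add_id (sketch, contiguous layout).
// a: [ne0, ne1, ne2], b: [ne0, n_rows], ids: [ne1, ne2]
void add_id_ref(const float * a, const float * b, const int32_t * ids,
                float * dst, int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {            // tokens
        for (int64_t i1 = 0; i1 < ne1; ++i1) {        // experts used per token
            const int32_t row = ids[i2*ne1 + i1];     // bias row for this expert
            const float * a_row = a   + (i2*ne1 + i1)*ne0;
            const float * b_row = b   + row*ne0;
            float       * d_row = dst + (i2*ne1 + i1)*ne0;
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                d_row[i0] = a_row[i0] + b_row[i0];
            }
        }
    }
}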


@@ -0,0 +1,8 @@
#include "common.cuh"
void ggml_cuda_op_add_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_add_id(const float * src0, const float * src1, const int32_t * src2, float * dst,
int64_t ne00, int64_t ne01, int64_t ne02,
int64_t ne0, int64_t ne1, size_t nb01, size_t nb02, size_t nb11, size_t nb21, cudaStream_t stream);


@@ -108,6 +108,23 @@ static const char * cu_get_error_str(CUresult err) {
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
static bool shared_memory_limit_raised[GGML_CUDA_MAX_DEVICES] = { false }; \
const int id = ggml_cuda_get_device(); \
if (!shared_memory_limit_raised[id]) { \
CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, nbytes)); \
shared_memory_limit_raised[id] = true; \
} \
} while (0)
#else
# define CUDA_SET_SHARED_MEMORY_LIMIT(kernel, nbytes) \
do { \
GGML_UNUSED(nbytes); \
} while (0)
#endif // !defined(GGML_USE_HIPBLAS) && !defined(GGML_USE_MUSA)
#if CUDART_VERSION >= 11100 || defined(GGML_USE_MUSA)
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
@@ -808,37 +825,7 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH
#endif
struct ggml_graph_node_properties {
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
};
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
~ggml_cuda_graph() {
if (instance != nullptr) {
CUDA_CHECK(cudaGraphExecDestroy(instance));
}
if (graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(graph));
}
}
cudaGraph_t graph = nullptr;
cudaGraphExec_t instance = nullptr;
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
std::vector<char **> updated_kernel_arg;
#endif
};
struct ggml_cuda_graph;
struct ggml_backend_cuda_context {
int device;
@@ -850,26 +837,9 @@ struct ggml_backend_cuda_context {
std::unique_ptr<ggml_cuda_graph> cuda_graph;
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
explicit ggml_backend_cuda_context(int device);
~ggml_backend_cuda_context() {
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
~ggml_backend_cuda_context();
cudaStream_t stream(int device, int stream) {
if (streams[device][stream] == nullptr) {


@@ -0,0 +1,262 @@
#pragma once
#include "ggml-common.h"
template<typename src_t, typename dst_t>
static __device__ __forceinline__ void convert_flt(const src_t * src, dst_t * dst) {
if constexpr (std::is_same_v<src_t, dst_t>) {
*dst = *src;
} else {
*dst = float(*src);
}
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void quantize_f32_q4_0_block(const float * __restrict__ x, block_q4_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q4_1_block(const float * __restrict__ x, block_q4_1 * __restrict__ y) {
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = x[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (x[0 + j] - vmin)*id;
const float x1 = (x[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
y->qs[j] = xi0;
y->qs[j] |= xi1 << 4;
}
}
static __device__ void quantize_f32_q5_0_block(const float * __restrict__ x, block_q5_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q5_1_block(const float * __restrict__ x, block_q5_1 * __restrict__ y) {
float min = x[0];
float max = x[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = x[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
y->dm.x = d;
y->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (x[0 + j] - min)*id;
const float x1 = (x[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(y->qh, &qh, sizeof(qh));
}
static __device__ void quantize_f32_q8_0_block(const float * __restrict__ x, block_q8_0 * __restrict__ y) {
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = x[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = x[j]*id;
y->qs[j] = roundf(x0);
}
}
static __device__ void quantize_f32_iq4_nl_block(const float * __restrict__ x, block_iq4_nl * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = x[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = x[0 + j]*id;
const float x1 = x[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
y->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = x[0 + j]*x[0 + j];
const float w1 = x[QK4_NL/2 + j]*x[QK4_NL/2 + j];
sumqx += w0*v0*x[j] + w1*v1*x[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
y->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
static __device__ void quantize_f32_q6_0_block(const float * __restrict__ xi, block_q6_0 * __restrict__ y) {
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK6_0; ++j) {
const float v = xi[j];
const float av = fabsf(xi[j]);
if (amax < av) {
amax = av;
vmax = v;
}
}
const float d = vmax / -32;
const float id = d ? 1.0f/d : 0.0f;
y->d = d;
memset(y->qh, 0, QK6_0/4);
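// qh packing: h carries bits 4-5 of xi0 (in h bits 0-1) and of xi1 (in h
// bits 2-3); pairs with j < QK6_0/4 land in the low nibble of qh[j % 8],
// the rest in the high nibble.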
for (int j = 0; j < QK6_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK6_0/2 + j]*id;
const uint8_t xi0 = min(63, (int8_t)(x0 + 32.5f));
const uint8_t xi1 = min(63, (int8_t)(x1 + 32.5f));
y->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
const uint8_t h = (xi0 >> 4) | ((xi1 >> 4) << 2);
y->qh[j%(QK6_0/4)] |= (h << 4*(j/(QK6_0/4)));
}
}
// Wrapper functions for cpy.cu compatibility
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
quantize_f32_q4_0_block((const float *)cxi, (block_q4_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
quantize_f32_q4_1_block((const float *)cxi, (block_q4_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
quantize_f32_q5_0_block((const float *)cxi, (block_q5_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
quantize_f32_q5_1_block((const float *)cxi, (block_q5_1 *)cdsti);
}
static __device__ void cpy_blck_f32_q6_0(const char * cxi, char * cdsti) {
quantize_f32_q6_0_block((const float *)cxi, (block_q6_0 *)cdsti);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
quantize_f32_q8_0_block((const float *)cxi, (block_q8_0 *)cdsti);
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
quantize_f32_iq4_nl_block((const float *)cxi, (block_iq4_nl *)cdsti);
}
template<typename src_t, typename dst_t>
static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
convert_flt((const src_t *)cxi, (dst_t *)cdsti);
}
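
As a sanity check on the blockwise schemes above, q8_0 is the simplest: scale by amax/127, round to int8, multiply back to dequantize. A host-side round-trip sketch (QK8_0 = 32, as in ggml-common.h):

#include <cmath>
#include <cstdint>

// Host-side q8_0 round trip over one block (sketch).
void q8_0_round_trip(const float x[32], float out[32]) {
    float amax = 0.0f;
    for (int j = 0; j < 32; ++j) amax = std::fmax(amax, std::fabs(x[j]));
    const float d  = amax / 127.0f;                 // block scale
    const float id = d ? 1.0f/d : 0.0f;
    for (int j = 0; j < 32; ++j) {
        const int8_t q = (int8_t) std::round(x[j]*id);  // quantize
        out[j] = q*d;                                   // dequantize
    }
}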

File diff suppressed because it is too large.


@@ -1,9 +1,11 @@
#include "common.cuh"
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_CPY_BLOCK_SIZE 64
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection = false);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_cpy_dest_ptrs_copy(ggml_cuda_graph * cuda_graph, char ** host_dest_ptrs, const int host_dest_ptrs_size, cudaStream_t stream);


@@ -86,6 +86,24 @@ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const in
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q6_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q6_0 * x = (const block_q6_0 *) vx;
const dfloat d = x[ib].d;
const uint8_t h = x[ib].qh[iqs%8] >> 2*(iqs/8);
v.x = ((x[ib].qs[iqs] & 0xf) | ((h & 0x3) << 4));
v.y = ((x[ib].qs[iqs] >> 4) | ((h & 0xc) << 2));
#ifdef GGML_CUDA_F16
v = __hsub2(v, {32.0f, 32.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 32.0f) * d;
v.y = (v.y - 32.0f) * d;
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;


@@ -22,6 +22,7 @@ typedef void (* fattn_kernel_t)(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -747,6 +748,7 @@ void launch_fattn(
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -837,6 +839,7 @@ void launch_fattn(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *) sinks->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, softcap, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1008,7 +1011,8 @@ void launch_fattn_mma(
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -1162,6 +1166,7 @@ void launch_fattn_mma(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *)sinks->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],


@@ -425,6 +425,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
@@ -584,6 +585,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum
// so it's being done unconditionally for every thread.
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new;
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
const float KQ_max_add = expf(sink - KQ_max_new);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}
if (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < D/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
} else {
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
for (int i = 0; i < D/tile_C_VKQ_16::J; ++i) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
}
}
}
}
}
// Write VKQ accumulators to shared memory in column-major format.
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// Also for np > 1 the combination is done via these values in shared memory.
@@ -823,6 +870,7 @@ static __global__ void flash_attn_mma_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -896,6 +944,7 @@ static __global__ void flash_attn_mma_ext_f16(
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
const float * sinks_f = sinks ? (const float *) sinks + channel * ncols2 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
@@ -906,12 +955,12 @@ static __global__ void flash_attn_mma_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -934,6 +983,7 @@ static __global__ void flash_attn_mma_ext_f16(
const half2 * V_h2 = (const half2 *) (V + nb12*(channel*ncols2 / gqa_ratio)); // K and V have same shape
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * D/2);
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
@@ -943,10 +993,10 @@ static __global__ void flash_attn_mma_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask); GGML_UNUSED(sinks);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);


@@ -43,37 +43,37 @@ struct fattn_mma_f16_config;
// Perhaps the 256 head size needs a closer look
// to see if this implementation is better.
//
//template <>
//struct fattn_mma_f16_config< 64, 64> {
// static constexpr int nbatch_fa = 64;
// static constexpr int nwarps_max = 4;
// static constexpr bool Q_in_reg = true;
// static constexpr int nstages_target = 2;
//
// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
// return 32;
// }
//
// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
// return 32;
// }
//
// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
// return 32;
// }
//
// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
// return 32;
// }
//
// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
// return 32;
// }
//
// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
// return 32;
// }
//};
template <>
struct fattn_mma_f16_config< 64, 64> {
static constexpr int nbatch_fa = 64;
static constexpr int nwarps_max = 4;
static constexpr bool Q_in_reg = true;
static constexpr int nstages_target = 2;
static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) {
return 32;
}
static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) {
return 32;
}
static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) {
return 32;
}
};
//
//template <>
//struct fattn_mma_f16_config< 80, 80> {
@@ -493,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
(V_h2 + k_VKQ_0*stride_V, tile_V, nbatch_V2, stride_V);
} else {
constexpr bool use_cp_async = nstages == 1;
if constexpr (ncols2 > 1 || mask_h2) {
if (ncols2 > 1 || mask_h2) {
flash_attn_ext_f16_load_mask<ncols1, nwarps, c::nbatch_fa, use_cp_async>(mask_h2 + k_VKQ_0/2, tile_mask, stride_mask);
}
}
@@ -576,7 +576,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
float KQ_rowsum_add[cols_per_thread] = {0.0f};
if constexpr (ntiles == 1) {
if constexpr (ncols2 > 1 || mask_h2) {
if (ncols2 > 1 || mask_h2) {
#pragma unroll
for (int i00 = 0; i00 < c::nbatch_fa; i00 += np*tile_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*tile_C_KQ::I;
@@ -818,6 +818,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half2 * const __restrict__ mask_h2,
const float * const __restrict__ sinks_f,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
@@ -975,6 +976,52 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
}
// If attention sinks are used, potentially re-scale if KQ_max is small.
// Also add the sink as a value to KQ_rowsum; this is done after synchronization of KQ_rowsum
// so it's being done unconditionally for every thread.
if (!is_fixup && (np == 1 || threadIdx.y % np == 0) && sinks_f) {
float KQ_max_scale[cols_per_thread];
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
static_assert(ntiles == 1 || ntiles == 2, "ntiles > 2 not implemented");
const int jc = ntiles == 1 ? 2*tile_C_VKQ::get_j(col/2) + col % 2 : tile_C_VKQ_16::get_i(col);
const float sink = sinks_f[jc % ncols2];
const float KQ_max_new = fmaxf(KQ_max[col], sink);
const float KQ_max_diff = KQ_max[col] - KQ_max_new;
KQ_max_scale[col] = expf(KQ_max_diff);
KQ_max[col] = KQ_max_new;
*((uint32_t *) &KQ_max_scale[col]) *= KQ_max_diff >= SOFTMAX_FTZ_THRESHOLD;
const float KQ_max_add = expf(sink - KQ_max_new);
KQ_rowsum[col] = KQ_max_scale[col]*KQ_rowsum[col] + KQ_max_add;
}
if (ntiles == 1) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[1]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < tile_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
} else {
#pragma unroll
for (int col = 0; col < cols_per_thread; ++col) {
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[col], KQ_max_scale[col]);
#pragma unroll
for (int i = 0; i < DV/tile_C_VKQ_16::J; ++i) {
#pragma unroll
for (int l0 = 0; l0 < tile_C_VKQ_16::ne; l0 += 2) {
VKQ_C_16[i*ntiles/2 + col/2].x[l0 + col % 2] *= KQ_max_scale_h2;
}
}
}
}
}
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8/4 threads each, so a full reduce is not needed.
{
@@ -1222,7 +1269,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
}
}
#else
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2); GGML_UNUSED(sinks_f);
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
@@ -1239,6 +1286,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -1323,6 +1371,7 @@ static __global__ void flash_attn_ext_f16(
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1335,12 +1384,12 @@ static __global__ void flash_attn_ext_f16(
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
}
@@ -1362,6 +1411,7 @@ static __global__ void flash_attn_ext_f16(
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
const half2 * mask_h2 = ncols2 > 1 || mask ? (const half2 *) mask + (nb31/sizeof(half2))*jt*ncols1 : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
const float * sinks_f = sinks ? (const float *) sinks + channel*ncols2 : nullptr;
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
@@ -1373,7 +1423,7 @@ static __global__ void flash_attn_ext_f16(
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
(Q_f2, K_h2, V_h2, mask_h2, sinks_f, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
@@ -1535,7 +1585,8 @@ static void launch_fattn_new_mma(
const ggml_tensor * K = dst->src[1];
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask  = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
ggml_tensor * KQV = dst;
@@ -1709,6 +1760,7 @@ static void launch_fattn_new_mma(
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
sinks ? ((const char *)sinks->data) : nullptr,
!stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, logit_softcap, n_head_log2,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
@@ -1853,6 +1905,11 @@ static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2(ggml_backend_cuda_con
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (use_gqa_opt && gqa_ratio % 16 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 16>(ctx, dst);
return;
}
if (use_gqa_opt && gqa_ratio % 8 == 0) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<DKQ, DV, 8>(ctx, dst);
return;
@@ -1878,8 +1935,6 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
const ggml_tensor * V = dst->src[2];
const ggml_tensor * mask = dst->src[3];
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
float max_bias = 0.0f;
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
@@ -1888,6 +1943,12 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
const int gqa_ratio = Q->ne[2] / K->ne[2];
if (K->ne[0] == 64 && V->ne[0] == 64) {
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst);
return;
}
GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512);
GGML_ASSERT(gqa_ratio % 16 == 0);
ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<576, 512, 16>(ctx, dst);

View File

@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,

View File

@@ -20,6 +20,7 @@ static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,

View File

@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -71,6 +72,7 @@ static __global__ void flash_attn_vec_ext_f16(
V += nb22*(blockIdx.y / gqa_ratio);
const half * maskh = (const half *) mask + ne11*ic0;
const float * sinksf = (const float *) (sinks);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
@@ -270,6 +272,39 @@ static __global__ void flash_attn_vec_ext_f16(
__syncthreads();
}
if (sinksf) {
const half sink = __float2half(sinksf[blockIdx.y]);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
half kqmax_new_j = kqmax_shared[j][threadIdx.y];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half KQ_max_scale = hexp(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const half val = hexp(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= __half2half2(KQ_max_scale);
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);

View File

@@ -17,6 +17,7 @@ static __global__ void flash_attn_vec_ext_f32(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -69,6 +70,7 @@ static __global__ void flash_attn_vec_ext_f32(
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const float * sinksf = (const float *) (sinks);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
@@ -254,6 +256,39 @@ static __global__ void flash_attn_vec_ext_f32(
__syncthreads();
}
if (sinksf) {
const float sink = sinksf[blockIdx.y];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (threadIdx.x == 0) {
kqmax_shared[j][threadIdx.y] = fmaxf(kqmax[j], sink);
}
}
__syncthreads();
#pragma unroll
for (int j = 0; j < ncols; ++j) {
float kqmax_new_j = kqmax_shared[j][threadIdx.y];
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j] - kqmax_new_j);
kqmax[j] = kqmax_new_j;
const float val = expf(sink - kqmax[j]);
kqsum[j] = kqsum[j]*KQ_max_scale;
if (tid == 0) {
kqsum[j] += val;
}
VKQ[j] *= KQ_max_scale;
}
__syncthreads();
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
kqsum[j] = warp_reduce_sum(kqsum[j]);

View File

@@ -5,6 +5,8 @@
// SPDX-License-Identifier: MIT
//
// TODO: attention sinks !!!
#include "common.cuh"
#include "fattn-common.cuh"
@@ -22,6 +24,7 @@ static __global__ void flash_attn_ext_f16(
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
@@ -93,6 +96,7 @@ static __global__ void flash_attn_ext_f16(
const half * V_h = (const half *) (V + nb22*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const float * sinks_f = sinks ? (const float *)sinks + blockIdx.y : nullptr;
const int stride_Q = nb01 / sizeof(float);
const int stride_K = nb11 / sizeof(half);

View File

@@ -539,7 +539,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
// As mentioned above, the new new MMA is slower than then the new MMA.
// As mentioned above, the new-new MMA is slower than the new MMA.
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
//ggml_cuda_flash_attn_ext_mma_new(ctx, dst);
}

View File

@@ -0,0 +1,41 @@
#pragma once
struct ggml_graph_node_properties {
void * node_address;
ggml_op node_op;
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
};
struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH
~ggml_cuda_graph() {
if (instance != nullptr) {
CUDA_CHECK(cudaGraphExecDestroy(instance));
}
if (graph != nullptr) {
CUDA_CHECK(cudaGraphDestroy(graph));
}
}
cudaGraph_t graph = nullptr;
cudaGraphExec_t instance = nullptr;
size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0;
std::vector<ggml_graph_node_properties> ggml_graph_properties;
bool use_cpy_indirection = false;
std::vector<char *> cpy_dest_ptrs;
char ** dest_ptrs_d;
int dest_ptrs_size = 0;
// Index to allow each cpy kernel to be aware of its position within the graph
// relative to other cpy nodes.
int graph_cpynode_index = -1;
#endif
};
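
For orientation, the members above track state for the standard CUDA graph lifecycle (capture, instantiate, then update or re-capture when the graph changes). A minimal standalone sketch of that lifecycle, assuming CUDA 12-style signatures and eliding error handling:

#include <cuda_runtime.h>

static void run_with_graph(cudaStream_t stream, cudaGraph_t & graph, cudaGraphExec_t & instance,
                           void (*enqueue_work)(cudaStream_t)) {
    cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);
    enqueue_work(stream);                      // enqueue the same kernels as a regular eval
    cudaStreamEndCapture(stream, &graph);
    if (instance == nullptr) {
        cudaGraphInstantiate(&instance, graph, 0);
    } else {
        // try a cheap in-place update first; re-instantiate if the topology changed
        cudaGraphExecUpdateResultInfo info;
        if (cudaGraphExecUpdate(instance, graph, &info) != cudaSuccess) {
            (void) cudaGetLastError();         // clear the error before retrying
            cudaGraphExecDestroy(instance);
            cudaGraphInstantiate(&instance, graph, 0);
        }
    }
    cudaGraphLaunch(instance, stream);
}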

View File

@@ -19,7 +19,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
}
template <bool vals_smem, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
static __global__ void soft_max_f32_nosinks(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2, float cap_params0, float cap_params1, bool do_softcap) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
@@ -124,7 +124,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, float * dst
}
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
static void soft_max_f32_cuda_nosinks(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, float cap_params0, float cap_params1, bool do_softcap, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
@@ -142,39 +142,40 @@ static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, cons
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
soft_max_f32_nosinks<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2, cap_params0, cap_params1, do_softcap);
}
}
#if 0
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
@@ -205,13 +206,14 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, 0, 0, false, stream);
}
}
#endif
void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
@@ -241,10 +243,283 @@ void ggml_cuda_op_soft_cap_max(ggml_backend_cuda_context & ctx, ggml_tensor * ds
if (use_f16) {
const half * src1_dd = (const half *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
} else {
const float * src1_dd = (const float *)src1_d;
soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
soft_max_f32_cuda_nosinks(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, params[0], params[1], params[2], params[3], true, stream);
}
}
struct soft_max_params {
int64_t nheads;
uint32_t n_head_log2;
int64_t ncols;
int64_t nrows_x;
int64_t nrows_y;
int64_t ne00;
int64_t ne01;
int64_t ne02;
int64_t ne03;
int64_t nb11;
int64_t nb12;
int64_t nb13;
int64_t ne12;
int64_t ne13;
float scale;
float max_bias;
float m0;
float m1;
};
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params p) {
const int ncols = ncols_template == 0 ? p.ncols : ncols_template;
const int tid = threadIdx.x;
const int64_t i03 = blockIdx.z;
const int64_t i02 = blockIdx.y;
const int64_t i01 = blockIdx.x;
//TODO: non-contiguous inputs/outputs
const int rowx = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
const int64_t i11 = i01;
const int64_t i12 = i02 % p.ne12;
const int64_t i13 = i03 % p.ne13;
x += int64_t(rowx)*ncols;
mask += (i11*p.nb11 + i12*p.nb12 + i13*p.nb13) / sizeof(T) * (mask != nullptr);
dst += int64_t(rowx)*ncols;
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const float slope = get_alibi_slope(p.max_bias, i02, p.n_head_log2, p.m0, p.m1);
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
float * vals = use_shared ? buf_iw + WARP_SIZE : dst;
float max_val = sinks ? sinks[i02] : -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const float val = x[col]*p.scale + (mask ? slope*t2f32(mask[col]) : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.0f; // partial sum
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const float val = expf(vals[col] - max_val);
tmp += val;
vals[col] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();
tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
if (sinks) {
tmp += expf(sinks[i02] - max_val);
}
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
return;
}
dst[col] = vals[col] * inv_sum;
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
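
A host-side reference for the kernel above, as a sketch that keeps only the sink semantics (mask and ALiBi slope omitted): the sink is one extra logit per head that absorbs probability mass without producing an output column.

#include <algorithm>
#include <cmath>

static void soft_max_row_with_sink(const float * x, float * dst, int n, float scale, float sink) {
    float max_val = sink;                          // like: max_val = sinks ? sinks[i02] : -INFINITY
    for (int i = 0; i < n; ++i) max_val = std::max(max_val, x[i]*scale);
    float sum = std::exp(sink - max_val);          // the sink's share of the denominator
    for (int i = 0; i < n; ++i) {
        dst[i] = std::exp(x[i]*scale - max_val);
        sum   += dst[i];
    }
    const float inv_sum = 1.0f/sum;
    for (int i = 0; i < n; ++i) dst[i] *= inv_sum; // output sums to slightly less than 1; the sink keeps the rest
}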
template<int... Ns, typename T>
static void launch_soft_max_kernels(const float * x, const T * mask, const float * sinks, float * dst,
const soft_max_params & p, cudaStream_t stream, dim3 block_dims, dim3 block_nums, size_t nbytes_shared)
{
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
auto launch_kernel = [=](auto I) -> bool {
constexpr int ncols = decltype(I)::value;
constexpr int block = (ncols > 1024 ? 1024 : ncols);
if (p.ncols == ncols) {
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, ncols, block, T>), smpbo);
soft_max_f32<true, ncols, block><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, mask, sinks, dst, p);
return true;
}
return false;
};
// unary fold over launch_kernel
if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...)) {
return;
}
//default case
CUDA_SET_SHARED_MEMORY_LIMIT((soft_max_f32<true, 0, 0, T>), smpbo);
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
}
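
The `if ((launch_kernel(std::integral_constant<int, Ns>{}) || ...))` line is a C++17 unary fold that tries each compile-time column count in order and stops at the first match. A stripped-down sketch of the same dispatch pattern:

#include <cstdio>

template <int N> static bool try_case(int n) {
    if (n != N) return false;
    std::printf("dispatched specialization N=%d\n", N);
    return true;
}

template <int... Ns> static void dispatch(int n) {
    // expands to try_case<32>(n) || try_case<64>(n) || ...; short-circuits on the first hit
    if ((try_case<Ns>(n) || ...)) return;
    std::printf("generic fallback for n=%d\n", n);   // analogous to soft_max_f32<true, 0, 0>
}

// usage: dispatch<32, 64, 128, 256, 512, 1024, 2048, 4096>(ncols);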
template<typename T>
static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(params.ne01, params.ne02, params.ne03);
const size_t nbytes_shared = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const int id = ggml_cuda_get_device();
const size_t smpbo = ggml_cuda_info().devices[id].smpbo;
if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else {
const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * src2 = dst->src[2];
const float * src0_d = (const float *) src0->data;
const void * src1_d = src1 ? (const void *) src1->data : nullptr;
const void * src2_d = src2 ? (const void *) src2->data : nullptr;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
const int64_t ne00 = src0->ne[0];
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float));
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;
const int64_t nb13 = src1 ? src1->nb[3] : 1;
const int64_t ne12 = src1 ? src1->ne[2] : 1;
const int64_t ne13 = src1 ? src1->ne[3] : 1;
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
soft_max_params params = {};
params.nheads = src0->ne[2];
params.n_head_log2 = n_head_log2;
params.ncols = ne00;
params.nrows_x = nrows_x;
params.nrows_y = nrows_y;
params.ne00 = src0->ne[0];
params.ne01 = src0->ne[1];
params.ne02 = src0->ne[2];
params.ne03 = src0->ne[3];
params.nb11 = nb11;
params.nb12 = nb12;
params.nb13 = nb13;
params.ne12 = ne12;
params.ne13 = ne13;
params.scale = scale;
params.max_bias = max_bias;
params.m0 = m0;
params.m1 = m1;
if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
}
}

View File

@@ -470,3 +470,83 @@ void ggml_cuda_op_sqrt(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
sqrt_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
template <typename T>
static __global__ void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, float alpha, float limit) {
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
// apply the base op, then multiply by the gate (either an offset within the same tensor or a separate tensor)
const int64_t j0 = (i / n) * o0 + (i % n);
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
float xi = x[j0];
float gi = g[j1];
xi = fminf(xi, limit);
gi = fmaxf(fminf(gi, limit), -limit);
float out_glu = xi / (1.0f + expf(-xi * alpha));
out_glu = out_glu * (1.0f + gi);
dst[i] = out_glu;
}
template <typename T>
static void swiglu_oai_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
const int64_t num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
swiglu_oai_kernel<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1, alpha, limit);
}
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
void * src0_d = src0->data;
void * src1_d = src1 ? src1->data : src0->data;
const int64_t src0_o = src0->nb[1];
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
void * dst_d = dst->data;
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == dst->type);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
GGML_ASSERT(src1->ne[0] == nc);
GGML_ASSERT(src0->type == src1->type);
}
//const int32_t swapped = ((const int32_t *) dst->op_params)[1];
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
const float * op_params = (const float *)dst->op_params;
const float alpha = op_params[2];
const float limit = op_params[3];
float * src0_p = (float *) src0_d;
float * src1_p = (float *) src1_d;
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc,
src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream);
}
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream) {
swiglu_oai_cuda(x, g, dst, k, n, o0, o1, alpha, limit, stream);
}
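
For reference, this CUDA path and the CPU path later in the commit compute the same element-wise function; a scalar sketch, with purely illustrative parameter values (alpha and limit actually come from the tensor's op params):

#include <math.h>

// out = min(x, limit) * sigmoid(alpha*min(x, limit)) * (1 + clamp(g, -limit, limit))
static float swiglu_oai_ref(float x, float g, float alpha, float limit) {
    x = fminf(x, limit);
    g = fmaxf(fminf(g, limit), -limit);
    return (x / (1.0f + expf(-alpha*x))) * (1.0f + g);
}

// e.g. swiglu_oai_ref(2.0f, 0.5f, 1.702f, 7.0f) -- example values only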

View File

@@ -47,3 +47,9 @@ void ggml_fused_mul_unary(ggml_backend_cuda_context & ctx, ggml_unary_op op,
int64_t nelements, const float * x, const float * y, float * z);
void ggml_cuda_op_multi_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_swiglu_oai_cuda_f32(const float * x, const float * g, float * dst, const int64_t k, const int64_t n,
const int64_t o0, const int64_t o1, const float alpha, const float limit, cudaStream_t stream);

View File

@@ -2823,7 +2823,6 @@ inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t
inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i] + y[i]; }
inline static void ggml_vec_add1_f32(const int n, float * z, const float * x, const float v) { for (int i = 0; i < n; ++i) z[i] = x[i] + v; }
inline static void ggml_vec_acc_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] += x[i]; }
inline static void ggml_vec_acc1_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] += v; }
@@ -2834,6 +2833,19 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) {
int i = 0;
#if defined(__AVX2__)
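// AVX2 main loop handles 8 floats per iteration; the scalar loop below processes the remainder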
for (; i + 7 < n; i += 8) {
__m256 vx = _mm256_loadu_ps(x + i);
__m256 vy = _mm256_loadu_ps(y + i);
__m256 vz = _mm256_add_ps(vx, vy);
_mm256_storeu_ps(z + i, vz);
}
#endif
for (; i < n; ++i) z[i] = x[i] + y[i];
}
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
assert(nrc == 1);
UNUSED(nrc);
@@ -4004,6 +4016,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"DUP",
"ADD",
"ADD_ID",
"ADD1",
"ACC",
"SUB",
@@ -4092,13 +4105,14 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK",
};
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"x",
"x+y",
"x[i]+y",
"x+y",
"view(x,nb,offset)+=y->x",
"x-y",
@@ -4187,7 +4201,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)",
};
static_assert(GGML_OP_COUNT == 81, "GGML_OP_COUNT != 81");
static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4207,9 +4221,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
"HARDSWISH",
"HARDSIGMOID",
"SWIGLU",
"SWIGLU_OAI",
};
static_assert(GGML_UNARY_OP_COUNT == 14, "GGML_UNARY_OP_COUNT != 14");
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
@@ -5917,6 +5932,29 @@ struct ggml_tensor * ggml_add_cast(
return ggml_add_cast_impl(ctx, a, b, type);
}
// ggml_add_id
struct ggml_tensor * ggml_add_id(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
struct ggml_tensor * ids) {
GGML_ASSERT(a->ne[0] == b->ne[0]);
GGML_ASSERT(a->ne[1] == ids->ne[0]);
GGML_ASSERT(a->ne[2] == ids->ne[1]);
GGML_ASSERT(ids->type == GGML_TYPE_I32);
struct ggml_tensor * result = ggml_dup_tensor(ctx, a);
result->op = GGML_OP_ADD_ID;
result->src[0] = a;
result->src[1] = b;
result->src[2] = ids;
return result;
}
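
That is, each row of `a` receives the bias row of `b` selected by the corresponding entry of `ids`. A contiguous-F32 reference of the semantics (a sketch; the real op honors strides):

#include <stdint.h>

// dst[i2][i1][:] = a[i2][i1][:] + b[ids[i2][i1]][:]
static void add_id_ref(float * dst, const float * a, const float * b, const int32_t * ids,
                       int64_t ne0, int64_t ne1, int64_t ne2) {
    for (int64_t i2 = 0; i2 < ne2; ++i2) {
        for (int64_t i1 = 0; i1 < ne1; ++i1) {
            const int32_t row = ids[i2*ne1 + i1];            // which bias row of b to add
            for (int64_t i0 = 0; i0 < ne0; ++i0) {
                dst[(i2*ne1 + i1)*ne0 + i0] = a[(i2*ne1 + i1)*ne0 + i0] + b[row*ne0 + i0];
            }
        }
    }
}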
// ggml_add1
static struct ggml_tensor * ggml_add1_impl(
@@ -6662,6 +6700,36 @@ struct ggml_tensor * ggml_swiglu(
return result;
}
struct ggml_tensor * ggml_swiglu_oai(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
float alpha,
float limit) {
GGML_ASSERT(ggml_is_contiguous_1(a));
if (b) {
GGML_ASSERT(ggml_is_contiguous_1(b));
GGML_ASSERT(ggml_are_same_shape(a, b));
GGML_ASSERT(a->type == b->type);
}
int64_t ne[4] = {a->ne[0]/2, a->ne[1], a->ne[2], a->ne[3]};
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
result->op = GGML_OP_UNARY;
result->grad = NULL;
result->src[0] = a;
result->src[1] = b;
ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_SWIGLU_OAI);
ggml_set_op_params_f32(result, 2, alpha);
ggml_set_op_params_f32(result, 3, limit);
return result;
}
// ggml_silu_back
struct ggml_tensor * ggml_silu_back(
@@ -7017,6 +7085,66 @@ struct ggml_tensor * ggml_moe_up_gate(
result->src[1] = as_gate;
result->src[2] = b;
result->src[3] = ids;
result->src[4] = NULL;
result->src[5] = NULL;
ggml_set_op_params_i32(result, 0, (int32_t) op);
return result;
}
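
A hypothetical call site for the extended variant defined next (tensor names are illustrative, not from this file); when both biases are NULL it reduces to the plain fused op above:

// build-graph sketch for a gpt-oss style MoE FFN with per-expert biases
struct ggml_tensor * moe = ggml_moe_up_gate_ext(ctx,
        ffn_up_exps, ffn_gate_exps,       // 3D expert weights, one matrix per expert
        cur,                              // activations
        selected_experts,                 // I32 ids produced by the router
        ffn_up_exps_b, ffn_gate_exps_b,   // optional per-expert biases (may be NULL)
        GGML_UNARY_OP_SWIGLU_OAI);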
struct ggml_tensor * ggml_moe_up_gate_ext(
struct ggml_context * ctx,
struct ggml_tensor * as_up,
struct ggml_tensor * as_gate,
struct ggml_tensor * b,
struct ggml_tensor * ids,
struct ggml_tensor * as_up_b,
struct ggml_tensor * as_gate_b,
enum ggml_unary_op op) {
if (!as_up_b && !as_gate_b) {
return ggml_moe_up_gate(ctx, as_up, as_gate, b, ids, op);
}
if (as_up->type != as_gate->type || !ggml_are_same_shape(as_up, as_gate)) {
struct ggml_tensor * result_up = ggml_mul_mat_id(ctx, as_up, b, ids);
if (as_up_b) {
result_up = ggml_add_id(ctx, result_up, as_up_b, ids);
}
struct ggml_tensor * result_gate = ggml_mul_mat_id(ctx, as_gate, b, ids);
if (as_gate_b) {
result_gate = ggml_add_id(ctx, result_gate, as_gate_b, ids);
}
return ggml_fused_mul_unary(ctx, result_gate, result_up, op);
}
GGML_ASSERT(!ggml_is_transposed(as_up));
GGML_ASSERT(!ggml_is_transposed(as_gate));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
GGML_ASSERT(as_up->ne[3] == 1); // as is 3d (one matrix per expert)
GGML_ASSERT(b->ne[3] == 1); // b is 3d
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
GGML_ASSERT(as_up->ne[0] == b->ne[0]); // can_mul_mat
GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
GGML_ASSERT(as_up->ne[1] == as_up_b->ne[0]);
GGML_ASSERT(as_gate->ne[1] == as_gate_b->ne[0]);
bool is_node = false;
const int64_t ne[4] = { as_up->ne[1], ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
result->op = GGML_OP_MOE_FUSED_UP_GATE;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = as_up;
result->src[1] = as_gate;
result->src[2] = b;
result->src[3] = ids;
result->src[4] = as_up_b;
result->src[5] = as_gate_b;
ggml_set_op_params_i32(result, 0, (int32_t) op);
@@ -7970,6 +8098,22 @@ struct ggml_tensor * ggml_soft_max_ext(
return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
}
void ggml_soft_max_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks) {
if (!sinks) {
a->src[2] = NULL;
return;
}
GGML_ASSERT(a->op == GGML_OP_SOFT_MAX);
GGML_ASSERT(a->src[2] == NULL);
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
a->src[2] = sinks;
}
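
Both sink hooks are post-hoc setters on an already built node, so existing call sites need no signature changes. A usage sketch (tensor names illustrative):

struct ggml_tensor * kq_soft = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, max_bias);
ggml_soft_max_add_sinks(kq_soft, attn_sinks);        // attn_sinks: F32 with ne[0] == n_head

// the flash-attention variant below works the same way on a GGML_OP_FLASH_ATTN_EXT node:
// ggml_flash_attn_ext_add_sinks(fa_node, attn_sinks);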
// ggml_soft_max_back
static struct ggml_tensor * ggml_soft_max_back_impl(
@@ -8833,6 +8977,22 @@ void ggml_flash_attn_ext_set_prec(
ggml_set_op_params_i32(a, 3, prec_i32); // scale is on first pos, max_bias on second
}
void ggml_flash_attn_ext_add_sinks(
struct ggml_tensor * a,
struct ggml_tensor * sinks) {
if (!sinks) {
a->src[4] = NULL;
return;
}
GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
GGML_ASSERT(a->src[4] == NULL);
GGML_ASSERT(a->src[0]->ne[2] == sinks->ne[0]);
GGML_ASSERT(sinks->type == GGML_TYPE_F32);
a->src[4] = sinks;
}
// ggml_flash_attn_back
struct ggml_tensor * ggml_flash_attn_back(
@@ -11497,6 +11657,77 @@ static void ggml_compute_forward_multi_add(
}
}
// ggml_compute_forward_add_id
static void ggml_compute_forward_add_id_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src2 = dst->src[2];
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(src2->type == GGML_TYPE_I32);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
const int nr = ggml_nrows(src0);
GGML_TENSOR_TERNARY_OP_LOCALS
GGML_ASSERT( nb0 == sizeof(float));
GGML_ASSERT(nb10 == sizeof(float));
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int ir = ir0; ir < ir1; ++ir) {
// src0 indices
const int i3 = ir/(ne2*ne1);
const int i2 = (ir - i3*ne2*ne1)/ne1;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1);
// expert id from src2 (ids) selecting the src1 bias row
const int i11 = *(int32_t *) ((char *) src2->data + i1*nb20 + i2*nb21);
GGML_ASSERT(i11 >= 0 && i11 < ne11);
ggml_vec_add_f32(ne0,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i11*nb11));
}
}
static void ggml_compute_forward_add_id(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_add_id_f32(params, dst);
} break;
default:
{
GGML_ABORT("unsupported type for ggml_compute_forward_add_id: %s", ggml_type_name(src0->type));
}
}
}
// ggml_compute_forward_add1
static void ggml_compute_forward_add1_f32(
@@ -13760,6 +13991,93 @@ static void ggml_compute_forward_swiglu(
}
}
// ggml_compute_forward_swiglu_oai
static void ggml_compute_forward_swiglu_oai_f32(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
char * src0_d = (char *) src0->data;
char * src1_d = (char *) (src1 ? src1->data : src0->data);
const size_t src0_o = src0->nb[1];
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
GGML_ASSERT(ggml_is_contiguous_1(src0));
GGML_ASSERT(ggml_is_contiguous_1(dst));
if (src1) {
GGML_ASSERT(ggml_is_contiguous_1(src1));
GGML_ASSERT(src0->type == src1->type);
}
const int ith = params->ith;
const int nth = params->nth;
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
const int nr = ggml_nrows(src0);
GGML_ASSERT(dst->ne[0] == nc);
GGML_ASSERT(ggml_nrows(dst) == nr);
const int32_t swapped = false; //ggml_get_op_params_i32(dst, 1);
const float alpha = ggml_get_op_params_f32(dst, 2);
const float limit = ggml_get_op_params_f32(dst, 3);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
float * src0_p = (float *) (src0_d + i1*src0_o);
float * src1_p = (float *) (src1_d + i1*src1_o);
float * dst_p = (float *) ((char *) dst->data + i1*(dst->nb[1]));
if (!src1) {
src0_p += swapped ? nc : 0;
src1_p += swapped ? 0 : nc;
}
for (int k = 0; k < nc; k++) {
const float x = MIN(src0_p[k], limit);
const float y = MAX(MIN(src1_p[k], limit), -limit);
const float out_glu = x / (1.f + expf(alpha * (-x)));
dst_p[k] = out_glu * (y + 1.f);
}
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
const float x = dst_p[k];
GGML_UNUSED(x);
assert(!isnan(x));
assert(!isinf(x));
}
#endif
}
}
static void ggml_compute_forward_swiglu_oai(
const struct ggml_compute_params * params,
struct ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0];
switch (src0->type) {
case GGML_TYPE_F32:
{
ggml_compute_forward_swiglu_oai_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
// ggml_compute_forward_fused_mul_unary
static void ggml_compute_forward_fused_mul_unary_f32(
@@ -15167,6 +15485,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
const struct ggml_tensor * src1 = dst->src[2];
const struct ggml_tensor * ids = dst->src[3];
const struct ggml_tensor * up_b = dst->src[4];
const struct ggml_tensor * gate_b = dst->src[5];
const struct ggml_tensor * src0_1 = dst->src[0];
const struct ggml_tensor * src0_2 = dst->src[1];
const struct ggml_tensor * src0 = src0_1; // so GGML_TENSOR_BINARY_OP_LOCALS works
@@ -15191,6 +15511,9 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne13 == 1);
const size_t nb41 = up_b ? up_b->nb[1] : 0;
const size_t nb51 = gate_b ? gate_b->nb[1] : 0;
// row groups
const int n_ids = ids->ne[0]; // n_expert_used
const int n_as = ne02; // n_expert
@@ -15278,6 +15601,8 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
const char * src0_1_cur = (const char *) src0_1->data + cur_a*nb02;
const char * src0_2_cur = (const char *) src0_2->data + cur_a*nb02;
const char * up_b_cur = up_b ? (const char *)up_b->data + cur_a*nb41 : NULL;
const char * gate_b_cur = gate_b ? (const char *)gate_b->data + cur_a*nb51 : NULL;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
@@ -15288,6 +15613,7 @@ static void ggml_compute_forward_mul_mat_id_up_gate(
if (!iqk_moe_fused_up_gate(nr0, nr1, ne00, ne11, dst->op_params[0],
type, src0_1_cur, src0_2_cur, nb01,
vec_dot_type, (const char *)wdata, row_size,
up_b_cur, gate_b_cur,
(float *)dst->data, nb1, nb2,
matrix_rows + cur_a*ne12, ith, nth)) GGML_ABORT("fatal error");
@@ -16645,6 +16971,7 @@ static void ggml_compute_forward_soft_max_f32(
const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1];
const struct ggml_tensor * src2 = dst->src[2];
assert(ggml_is_contiguous(dst));
assert(ggml_are_same_shape(src0, dst));
@@ -16662,6 +16989,13 @@ static void ggml_compute_forward_soft_max_f32(
GGML_TENSOR_UNARY_OP_LOCALS
const int64_t nb11 = src1 ? src1->nb[1] : 1;
const int64_t nb12 = src1 ? src1->nb[2] : 1;
const int64_t nb13 = src1 ? src1->nb[3] : 1;
const int64_t ne12 = src1 ? src1->ne[2] : 1;
const int64_t ne13 = src1 ? src1->ne[3] : 1;
//const int64_t ne11 = src1 ? src1->ne[1] : 1;
// TODO: is this supposed to be ceil instead of floor?
@@ -16673,67 +17007,80 @@ static void ggml_compute_forward_soft_max_f32(
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
const int nc = src0->ne[0];
const int nr = ggml_nrows(src0);
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
for (int i1 = ir0; i1 < ir1; i1++) {
// ALiBi
const uint32_t h = (i1/ne01)%ne02; // head
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
// sinks
const float * sk = src2 ? (float *)((char *) src2->data) : NULL;
float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const int64_t i11 = i01;
const int64_t i12 = i02%ne12;
const int64_t i13 = i03%ne13;
// broadcast the mask across rows
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
float * mp_f32 = src1 ? (float *)((char *) src1->data) + (i1%ne01)*ne00 : NULL;
// ALiBi
const uint32_t h = i02; // head
const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
ggml_vec_cpy_f32 (nc, wp, sp);
ggml_vec_scale_f32(nc, wp, scale);
if (mp_f32) {
if (use_f16) {
for (int i = 0; i < nc; ++i) {
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
float * sp = (float *)((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
float * dp = (float *)((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
// broadcast the mask across rows
ggml_fp16_t * mp_f16 = src1 ? (ggml_fp16_t *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
float * mp_f32 = src1 ? (float *)((char *) src1->data + i11*nb11 + i12*nb12 + i13*nb13) : NULL;
ggml_vec_cpy_f32 (ne00, wp, sp);
ggml_vec_scale_f32(ne00, wp, scale);
if (mp_f32) {
if (use_f16) {
for (int i = 0; i < ne00; ++i) {
wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
}
} else {
for (int i = 0; i < ne00; ++i) {
wp[i] += slope*mp_f32[i];
}
}
}
} else {
for (int i = 0; i < nc; ++i) {
wp[i] += slope*mp_f32[i];
#ifndef NDEBUG
for (int i = 0; i < ne00; ++i) {
//printf("p[%d] = %f\n", i, p[i]);
assert(!isnan(wp[i]));
}
#endif
float max = -INFINITY;
ggml_vec_max_f32(ne00, &max, wp);
// if we have sinks, make a correction as if they were included in the softmax
if (sk) {
max = MAX(max, sk[i02]);
}
ggml_float sum = ggml_vec_soft_max_f32(ne00, dp, wp, max);
assert(sum > 0.0);
if (sk) {
sum += (ggml_float) expf(sk[i02] - max);
}
sum = 1.0/sum;
ggml_vec_scale_f32(ne00, dp, sum);
#ifndef NDEBUG
for (int i = 0; i < ne00; ++i) {
assert(!isnan(dp[i]));
assert(!isinf(dp[i]));
}
#endif
}
}
//#ifndef NDEBUG
// for (int i = 0; i < nc; ++i) {
// //printf("p[%d] = %f\n", i, p[i]);
// assert(!isnan(wp[i]));
// }
//#endif
float max = -INFINITY;
ggml_vec_max_f32(nc, &max, wp);
ggml_float sum = ggml_vec_soft_max_f32(nc, dp, wp, max);
//assert(sum > 0.0);
sum = 1.0/sum;
ggml_vec_scale_f32(nc, dp, sum);
//#ifndef NDEBUG
// for (int i = 0; i < nc; ++i) {
// assert(!isnan(dp[i]));
// assert(!isinf(dp[i]));
// }
//#endif
}
}
@@ -16755,7 +17102,6 @@ static void ggml_compute_forward_soft_max(
}
}
// ggml_compute_forward_soft_max_back
static void ggml_compute_forward_soft_max_back_f32(
@@ -18308,12 +18654,14 @@ static void ggml_compute_forward_argsort_thresh(
static void ggml_compute_forward_flash_attn_ext_f16(
const struct ggml_compute_params * params,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
const struct ggml_tensor * q = dst->src[0];
const struct ggml_tensor * k = dst->src[1];
const struct ggml_tensor * v = dst->src[2];
const struct ggml_tensor * mask = dst->src[3];
const struct ggml_tensor * sinks = dst->src[4];
GGML_TENSOR_LOCALS(int64_t, neq, q, ne)
GGML_TENSOR_LOCALS(size_t, nbq, q, nb)
GGML_TENSOR_LOCALS(int64_t, nek, k, ne)
@@ -18383,6 +18731,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
#if GGML_USE_IQK_MULMAT
// Sinks are forwarded to the iqk FA implementation below.
if (iqk_flash_attn_noalibi(q->type, mask->type, max_bias,
q->ne[3], q->ne[2], q->nb[3], q->nb[2],
k->ne[3], k->ne[2], k->nb[3], k->nb[2],
@@ -18390,7 +18739,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
dst->ne[2], dst->ne[1], dst->nb[1],
k->type, v->type,
Dk, Dv, neq1, nek1, q->nb[1], k->nb[1], v->nb[1], mask->nb[1],
q->data, k->data, v->data, mask->data,
q->data, k->data, v->data, mask->data, sinks ? sinks->data : NULL,
scale, softcap, (float *)dst->data,
params->wdata, (barrier_t)ggml_barrier, (void *)params->shared, ith, nth)) return;
@@ -18447,6 +18796,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
ggml_vec_dot_t const kq_vec_dot = type_traits[k->type].vec_dot;
ggml_to_float_t const v_to_float = type_traits[v->type].to_float;
GGML_ASSERT(( q_to_vec_dot) && "fattn: unsupported K-type");
GGML_ASSERT((v->type == GGML_TYPE_F32 || v_to_float ) && "fattn: unsupported V-type");
const int64_t Dkv = MAX(Dk, Dv);
// loop over n_batch and n_head
@@ -18552,6 +18904,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
}
}
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
ggml_vec_scale_f32(Dv, VKQ32, ms);
} else {
vs = expf(s - M);
}
S = S*ms + vs;
}
// V /= S
const float S_inv = 1.0f/S;
ggml_vec_scale_f32(Dv, VKQ32, S_inv);
@@ -18571,17 +18939,13 @@ static void ggml_compute_forward_flash_attn_ext_f16(
static void ggml_compute_forward_flash_attn_ext(
const struct ggml_compute_params * params,
const struct ggml_tensor * q,
const struct ggml_tensor * k,
const struct ggml_tensor * v,
const struct ggml_tensor * mask,
struct ggml_tensor * dst) {
switch (dst->op_params[3]) {
case GGML_PREC_DEFAULT:
case GGML_PREC_F32:
{
// uses F32 accumulators
ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
ggml_compute_forward_flash_attn_ext_f16(params, dst);
} break;
default:
{
@@ -19350,6 +19714,10 @@ static void ggml_compute_forward_unary(
{
ggml_compute_forward_swiglu(params, dst);
} break;
case GGML_UNARY_OP_SWIGLU_OAI:
{
ggml_compute_forward_swiglu_oai(params, dst);
} break;
case GGML_UNARY_OP_HARDSWISH:
{
ggml_compute_forward_hardswish(params, dst);
@@ -19898,6 +20266,10 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_add(params, tensor);
} break;
case GGML_OP_ADD_ID:
{
ggml_compute_forward_add_id(params, tensor);
} break;
case GGML_OP_ADD1:
{
ggml_compute_forward_add1(params, tensor);
@@ -20136,7 +20508,7 @@ static bool ggml_compute_forward(struct ggml_compute_params * params, struct ggm
} break;
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
ggml_compute_forward_flash_attn_ext(params, tensor);
} break;
case GGML_OP_FLASH_ATTN_BACK:
{
@@ -20486,6 +20858,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
}
} break;
case GGML_OP_ADD_ID:
{
GGML_ABORT("fatal error"); // TODO: implement
} break;
case GGML_OP_ADD1:
{
if (src0->grad) {
@@ -21719,6 +22095,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_OP_DUP:
case GGML_OP_CONT:
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
case GGML_OP_ACC:
case GGML_OP_MULTI_ADD:
@@ -21758,6 +22135,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
case GGML_UNARY_OP_GELU_QUICK:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_SWIGLU:
case GGML_UNARY_OP_SWIGLU_OAI:
{
n_tasks = n_threads;
} break;
@@ -21952,6 +22330,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
}
} break;
case GGML_OP_ADD:
case GGML_OP_ADD_ID:
case GGML_OP_ADD1:
{
if (ggml_is_quantized(node->src[0]->type)) {

View File

@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_128_128) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<128, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<128, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (nk%128 == 0) {
return iqk_flash_helper_T<128, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<128, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
return iqk_flash_helper_T<128, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}

View File

@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_192_128) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<192, 128, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<192, 128, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (nk%128 == 0) {
return iqk_flash_helper_T<192, 128, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<192, 128, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
return iqk_flash_helper_T<192, 128, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}

View File

@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_256_256) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<256, 256, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<256, 256, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (nk%128 == 0) {
return iqk_flash_helper_T<256, 256, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<256, 256, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
return iqk_flash_helper_T<256, 256, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}

View File

@@ -9,7 +9,8 @@ namespace {
template <int step_k, typename KHelper, typename VHelper>
inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
const float * q, const char * mask, float scale, float softcap, float * qkv,
const float * sinkf, float * M, float * S) {
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
nq1 -= n;
if (nq1 == 0) return true;
@@ -21,29 +22,29 @@ inline void iqk_deepseek_helper(KHelper& kh, VHelper& vh,
};
if (nq1 >= 16) {
int n_step = nq1/16;
FlashAttn<576, 512, 16, step_k> fa(scale, softcap);
FlashAttn<576, 512, 16, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(16*n_step)) return;
}
if (nq1 >= 8) {
int n_step = nq1/8;
FlashAttn<576, 512, 8, step_k> fa(scale, softcap);
FlashAttn<576, 512, 8, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(8*n_step)) return;
}
if (nq1 >= 4) {
int n_step = nq1/4;
FlashAttn<576, 512, 4, step_k> fa(scale, softcap);
FlashAttn<576, 512, 4, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(4*n_step)) return;
}
if (nq1 >= 2) {
int n_step = nq1/2;
FlashAttn<576, 512, 2, step_k> fa(scale, softcap);
FlashAttn<576, 512, 2, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
if (update(2*n_step)) return;
}
FlashAttn<576, 512, 1, step_k> fa(scale, softcap);
FlashAttn<576, 512, 1, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, qkv, M, S);
}
@@ -51,37 +52,37 @@ template <int step_k>
inline bool iqk_deepseek_helper(ggml_type type_k,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) {
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
if (type_k == GGML_TYPE_Q8_0) {
HelperQ80 kh((const char *)k, stride_k);
HelperQ80 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
if (type_k == GGML_TYPE_Q8_0_R8) {
HelperQ80R8<576> kh((const char *)k, stride_k);
HelperQ80 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
if (type_k == GGML_TYPE_Q6_0) {
HelperQ60 kh((const char *)k, stride_k);
HelperQ60 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
#if GGML_IQK_FA_ALL_QUANTS
if (type_k == GGML_TYPE_Q8_KV) {
HelperQ8KV<576> kh((const char *)k, stride_k);
HelperQ8KV<512> vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (type_k == GGML_TYPE_F16) {
HelperF16 kh((const char *)k, stride_k);
HelperF16 vh((const char *)v, stride_v);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_deepseek_helper<step_k>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
return true;
}
#ifdef __AVX512BF16__
@@ -89,10 +90,10 @@ inline bool iqk_deepseek_helper(ggml_type type_k,
HelperBF16<576, step_k> kh((const char *)k, stride_k);
HelperBF16<512, step_k> vh((const char *)v, stride_v);
if (nq1 % 8 == 0) {
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap);
FlashAttnBF16<576, 512, 8, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} else {
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap);
FlashAttnBF16<576, 512, 1, step_k> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
return true;
@@ -113,7 +114,7 @@ IQK_FA_CASE(iqk_fa_576_512) {
}
stride_q /= sizeof(float); // q stride as float
return iqk_deepseek_helper<32>(type_k, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, M, S);
q, (const char *)k, (const char *)v, (const char *)mask, scale, softcap, qkv, sinkf, M, S);
}

View File

@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_64_64) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<64, 64, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<64, 64, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (nk%128 == 0) {
return iqk_flash_helper_T<64, 64, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<64, 64, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
return iqk_flash_helper_T<64, 64, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}

View File

@@ -19,26 +19,26 @@ IQK_FA_CASE(iqk_fa_96_96) {
if (type_v != GGML_TYPE_BF16) return false; // we do not support mixing bf16 k-cache with other types
if (nk%64 == 0) {
iqk_flash_helper_T<96, 96, 64>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
iqk_flash_helper_T<96, 96, 32>(nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
return true;
}
#endif
if (nk%128 == 0) {
return iqk_flash_helper_T<96, 96, 128>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
if (nk%64 == 0) {
return iqk_flash_helper_T<96, 96, 64>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}
return iqk_flash_helper_T<96, 96, 32>(type_k, type_v, nq, nk, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, ck, cv, cm, scale, softcap, qkv, M, S);
q, ck, cv, cm, scale, softcap, qkv, sinkf, M, S);
}

View File

@@ -1141,10 +1141,25 @@ struct FlashQKV {
}
template <typename FMS>
inline void normalize_and_store_1row(const FMS& fms, int j, const qkv_cache_t * R, float * qkv) const {
inline void normalize_and_store_1row(const FMS& fms, int j, qkv_cache_t * R, float * qkv, const float * sinkf) const {
static_assert(q_step == FMS::q_step);
GGML_ASSERT(fms.S[j] > 0);
auto norm = F16::set1(1/fms.S[j]);
float S = fms.S[j];
if (sinkf) {
float s = *sinkf;
if (s > fms.M[j]) {
float m = expf(fms.M[j] - s);
auto vm = F16::set1(m);
for (int i = 0; i < D/F16::block_size; ++i) {
auto Ri = R + F16::block_size*i;
F16::store(Ri, F16::mul(vm, F16::load(Ri)));
}
S = S*m + 1;
} else {
S += expf(s - fms.M[j]);
}
}
GGML_ASSERT(S > 0);
auto norm = F16::set1(1/S);
//auto norm = F16::set1(fms.S[j] > 0 ? 1/fms.S[j] : 0.f);
for (int i = 0; i < D/F16::block_size; ++i) {
auto r = F16::load(R + F16::block_size*i);
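
Note on the sink handling above: an attention sink contributes exp(s) to the softmax denominator but adds no rows to the value accumulator, so it folds into the usual streaming log-sum-exp merge. Because this runs as the final normalization step, the row maximum M no longer needs updating afterwards. A minimal scalar sketch of the same update (standalone helper, names hypothetical):

#include <cmath>

// R: accumulated (unnormalized) output row of length D
// M: running max logit of the row, S: running sum of exp(logit - M)
// s: the head's sink logit; afterwards R[i] /= S normalizes the row
inline void apply_sink(float s, float * R, int D, float M, float & S) {
    if (s > M) {
        float m = std::exp(M - s);             // rescale old contributions to the new max
        for (int i = 0; i < D; ++i) R[i] *= m;
        S = S*m + 1.0f;                        // the sink itself contributes exp(s - s) == 1
    } else {
        S += std::exp(s - M);                  // sink sits under the existing max
    }
}
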
@@ -1153,7 +1168,7 @@ struct FlashQKV {
}
template <typename FMS>
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, float * M, float * S) const {
inline void normalize_and_store(const FMS& fms, int nq1, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
static_assert(q_step == FMS::q_step);
if (M && S) {
std::memcpy(M, fms.M, nq1*sizeof(float));
@@ -1173,7 +1188,7 @@ struct FlashQKV {
} else {
auto R = qkv_cache;
for (int j = 0; j < nq1; ++j) {
normalize_and_store_1row(fms, j, R, qkv);
normalize_and_store_1row(fms, j, R, qkv, sinkf);
qkv += stride_qkv;
R += D;
}
@@ -1181,7 +1196,7 @@ struct FlashQKV {
}
template <typename FMS>
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, float * M, float * S) const {
inline void normalize_and_store(const FMS& fms, int stride_qkv, float * qkv, const float * sinkf, float * M, float * S) {
static_assert(q_step == FMS::q_step);
if (M && S) {
std::memcpy(M, fms.M, q_step*sizeof(float));
@@ -1201,7 +1216,7 @@ struct FlashQKV {
} else {
auto R = qkv_cache;
for (int j = 0; j < q_step; ++j) {
normalize_and_store_1row(fms, j, R, qkv);
normalize_and_store_1row(fms, j, R, qkv, sinkf);
qkv += stride_qkv;
R += D;
}
@@ -1332,7 +1347,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv,
float * M, float * S) {
const float * sinkf, float * M, float * S) {
#ifdef __aarch64__
float16_t q_f16[Dk*q_step];
#endif
@@ -1356,7 +1371,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
vh.next_block(k_step);
mr += k_step*sizeof(ggml_half);
}
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
q += q_step*stride_q;
mask += q_step*stride_m;
@@ -1383,7 +1398,7 @@ void compute_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, in
vh.next_block(k_step);
mr += k_step*sizeof(ggml_half);
}
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
}
}
@@ -1392,7 +1407,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
FlashMS<q_step, k_step>& fms,
FlashQKV<Dv, q_step, k_step>& fqkv,
const float * q, const char * mask, float * qkv,
float * M, float * S, char * qptr) {
const float * sinkf, float * M, float * S, char * qptr) {
auto q8 = (typename KHelper::block_q8 *)qptr;
if constexpr (q_step > 1 && std::is_same_v<KHelper, HelperQ80>) {
if (nq1 == q_step) {
@@ -1412,7 +1427,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
vh.next_block(k_step);
mr += k_step*sizeof(ggml_half);
}
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
return;
}
}
@@ -1449,10 +1464,10 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
}
#if FA_TIMING
t1 = Perf::cur_time();
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
perf.accum_nolock(3, t1);
#else
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
#endif
q += q_step*stride_q;
@@ -1474,7 +1489,7 @@ void compute_helper_q(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q,
vh.next_block(k_step);
mr += k_step*sizeof(ggml_half);
}
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
}
#if FA_TIMING
Perf::instance().add(perf);
@@ -1504,7 +1519,7 @@ struct FlashAttn {
static_assert(k_step%F16::block_size == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
FlashAttn(float scale, float softcap) : fms(scale, softcap) {}
FlashAttn(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
@@ -1533,7 +1548,7 @@ struct FlashAttn {
HelperQ80R8<Dk> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ80R8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
return;
}
@@ -1547,29 +1562,30 @@ struct FlashAttn {
HelperQ8KVR8<Dk> khr4(nk1, kh);
#endif
compute_helper_q<Dk, Dv, q_step, k_step, HelperQ8KVR8<Dk>, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
khr4, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
return;
}
#endif
}
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, qptr);
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, qptr);
}
else {
typename KHelper::block_q8 q8[q_step*(Dk/KHelper::block_size_q)];
compute_helper_q<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S, (char *)q8);
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S, (char *)q8);
}
}
else {
compute_helper<Dk, Dv, q_step, k_step, KHelper, VHelper, FlashQKfp32<Dk, q_step, k_step>>(
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, M, S);
kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, fms, fqkv, q, mask, qkv, sinkf, M, S);
}
}
FlashMS<q_step, k_step> fms;
FlashQKV<Dv, q_step, k_step> fqkv;
const float * sinkf;
};
@@ -1927,7 +1943,7 @@ struct FlashAttnBF16 {
static_assert(k_step%32 == 0);
static_assert(q_step <= 4 || q_step%4 == 0);
FlashAttnBF16(float scale, float softcap) : fms(scale, softcap) {}
FlashAttnBF16(float scale, float softcap, const float * sinkf) : fms(scale, softcap), sinkf(sinkf) {}
template <typename KHelper, typename VHelper>
void compute(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
@@ -1967,7 +1983,7 @@ struct FlashAttnBF16 {
#if FA_TIMING
t1 = Perf::cur_time();
#endif
fqkv.normalize_and_store(fms, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, stride_qkv, qkv, sinkf, M, S);
#if FA_TIMING
perf.accum_nolock(4, t1);
#endif
@@ -1990,7 +2006,7 @@ struct FlashAttnBF16 {
vh.next_block(k_step);
mr += k_step*sizeof(ggml_half);
}
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, M, S);
fqkv.normalize_and_store(fms, n_left, stride_qkv, qkv, sinkf, M, S);
}
#if FA_TIMING
Perf::instance().add(perf);
@@ -1999,12 +2015,14 @@ struct FlashAttnBF16 {
FlashMS<q_step, k_step> fms;
FlashQKV<Dv, q_step, k_step> fqkv;
const float * sinkf;
};
#endif
template <int Dk, int Dv, int k_step, typename KHelper, typename VHelper>
inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int stride_q, int stride_m, int stride_qkv,
const float * q, const char * mask, float scale, float softcap, float * qkv, float * M, float * S) {
const float * q, const char * mask, float scale, float softcap, float * qkv,
const float * sinkf, float * M, float * S) {
auto update = [&nq1, &mask, &q, &qkv, &M, &S, stride_q, stride_m, stride_qkv] (int n) {
nq1 -= n;
@@ -2018,48 +2036,48 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
if (nk1 >= 512) {
if (nq1 >= 128) {
int n_step = nq1/128;
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 128*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(128*n_step)) return;
}
if (nq1 >= 64) {
int n_step = nq1/64;
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 64*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(64*n_step)) return;
}
if (nq1 >= 32) {
int n_step = nq1/32;
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 32, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 32*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(32*n_step)) return;
}
if (nq1 >= 16) {
int n_step = nq1/16;
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 16*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(16*n_step)) return;
}
}
if (nq1 >= 8) {
int n_step = nq1/8;
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 8*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(8*n_step)) return;
}
else if (nq1 >= 4) {
int n_step = nq1/4;
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 4, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 4*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(4*n_step)) return;
}
else if (nq1 >= 2) {
int n_step = nq1/2;
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 2, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, 2*n_step, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
if (update(2*n_step)) return;
}
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap);
FlashAttn<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
@@ -2067,26 +2085,26 @@ inline void iqk_flash_helper(KHelper& kh, VHelper& vh, int nq1, int nk1, int str
template <int Dk, int Dv, int k_step>
inline void iqk_flash_helper_T(int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) {
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
HelperBF16<Dk, k_step> kh(k, stride_k);
HelperBF16<Dv, k_step> vh(v, stride_v);
if (nk1 >= 4096) {
if (nq1 >= 64) {
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap);
FlashAttnBF16<Dk, Dv, 64, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
else if (nq1 >= 16) {
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap);
FlashAttnBF16<Dk, Dv, 16, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
return;
}
}
if (nq1 >= 8) {
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap);
FlashAttnBF16<Dk, Dv, 8, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
} else {
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap);
FlashAttnBF16<Dk, Dv, 1, k_step> fa(scale, softcap, sinkf);
fa.compute(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, (const char *)mask, qkv, M, S);
}
}
@@ -2096,43 +2114,43 @@ template <int Dk, int Dv, int k_step, typename KHelper>
inline bool iqk_flash_helper_T(KHelper& kh, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) {
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
switch (type_v) {
case GGML_TYPE_F16: {
HelperF16 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
#ifdef __AVX512BF16__
case GGML_TYPE_BF16: {
HelperBF16<Dv, k_step> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
#endif
case GGML_TYPE_Q8_0: {
HelperQ80 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dv> vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q4_0: {
HelperQ40 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q4_1: {
HelperQ41 vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4nl vh(v, stride_v);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, M, S);
iqk_flash_helper<Dk, Dv, k_step>(kh, vh, nq1, nk1, stride_q, stride_m, stride_qkv, q, mask, scale, softcap, qkv, sinkf, M, S);
} break;
#endif
default: return false;
@@ -2144,42 +2162,42 @@ template <int Dk, int Dv, int k_step>
inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
int nq1, int nk1, int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,
const float * q, const char * k, const char * v, const char * mask,
float scale, float softcap, float * qkv, float * M, float * S) {
float scale, float softcap, float * qkv, const float * sinkf, float * M, float * S) {
bool result = false;
switch (type_k) {
case GGML_TYPE_F16: {
HelperF16 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q8_0: {
HelperQ80 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q8_0_R8: {
HelperQ80R8<Dk> kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q6_0: {
HelperQ60 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
#if GGML_IQK_FA_ALL_QUANTS
case GGML_TYPE_Q8_KV: {
HelperQ8KV<Dk> kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q4_0: {
HelperQ40 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_Q4_1: {
HelperQ41 kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
case GGML_TYPE_IQ4_NL: {
HelperIQ4nl kh(k, stride_k);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, M, S);
result = iqk_flash_helper_T<Dk, Dv, k_step>(kh, type_v, nq1, nk1, stride_q, stride_v, stride_m, stride_qkv, q, v, mask, scale, softcap, qkv, sinkf, M, S);
} break;
#endif
default: break;
@@ -2194,7 +2212,7 @@ inline bool iqk_flash_helper_T(ggml_type type_k, ggml_type type_v,
int stride_q, int stride_k, int stride_v, int stride_m, int stride_qkv,\
const float * q, const void * k, const void * v, const void * mask,\
float scale, float softcap,\
float * qkv, float * M, float * S)
float * qkv, const float * sinkf, float * M, float * S)
IQK_FA_CASE(iqk_fa_576_512);
IQK_FA_CASE(iqk_fa_192_128);

View File

@@ -66,6 +66,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // attention sinks. If not null, assumed to be fp32, one value per head
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))
@@ -139,7 +140,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto work_this_thread = (float *)(result_buffer + ith*size_thread);
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, nq_this_thread, nek1/gcd_k, nbq2, stride_k, stride_v, 0, Dv, //Dk*sizeof(uint16_t), Dv,
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth,
(const float *)qth, (const void *)kth, (const void *)vth, (const void *)mth, nullptr, 0,
scale, softcap,
work_this_thread, work_this_thread + (Dv+0)*nq_this_thread, work_this_thread + (Dv+1)*nq_this_thread)) return false;
@@ -182,51 +183,6 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
if (neq3 == 1 && rk2 > 1 && rk2 == rv2 && neq1 == 1 && nth >= 1 && nek2*nek1 >= 32*nth) {
auto result_size = (Dv + 16)*rk2*sizeof(float);
int gcd = simple_gcd(nek2, nth);
if (false && gcd > 1) {
int nth_g = nth/gcd;
int ith_g = ith%nth_g;
int nek1_32 = nek1/32;
int nek1_pt = (nek1_32 + nth_g - 1)/nth_g;
int ith_mid = nth_g;
if (nek1_pt*nth_g > nek1_32) {
ith_mid = nek1_32 - nth_g*(nek1_pt - 1);
}
nek1_pt *= 32;
int nek1_mid = ith_mid*nek1_pt;
int nek1_thread = ith_g < ith_mid ? nek1_pt : nek1_pt - 32;
for (int ik02 = ith/nth_g; ik02 < nek2; ik02 += gcd) {
int ik01 = ith_g < ith_mid ? ith_g*nek1_pt : nek1_mid + (ith_g - ith_mid)*nek1_thread;
auto this_result = (float *)((char *)work_buffer + (ik02*nth_g + ith_g)*result_size);
auto this_q = (const float *)((const char *)q + ik02*rk2*nbq2);
auto this_k = (const char *)k + ik01*stride_k + ik02*nbk2;
auto this_v = (const char *)v + ik01*stride_v + ik02*nbv2;
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, nek1_thread, nbq2, stride_k, stride_v, 0, Dv,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
}
barrier(barrier_data);
for (int iq2 = ith; iq2 < neq2; iq2 += nth) {
int ik02 = iq2/rk2;
int il = iq2 - ik02*rk2;
auto Racc = qkv + iq2*nb1/sizeof(float);
float M = -INFINITY, S = 0;
for (int ig = 0; ig < nth_g; ++ig) {
int istep_k = ik02*nth_g + ig;
auto this_result = (float *)((char *)work_buffer + istep_k*result_size);
const float * R = this_result + il*Dv;
const float * Mj = this_result + Dv*rk2;
const float * Sj = Mj + rk2;
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
return true;
}
int nth_k = nth/gcd;
int nek2_k = nek2/gcd;
int nchunk = nek2_k*nek1/32;
@@ -259,7 +215,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
auto this_m = (const char *)mask + ik01*sizeof(uint16_t); // we don't have ggml_half available here
if (!iqk_flash_attn_impl(int_type_k, int_type_v,
Dk, Dv, rk2, this_nk, nbq2, stride_k, stride_v, 0, Dv,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m,
this_q, (const void *)this_k, (const void *)this_v, (const void *)this_m, nullptr, 0,
scale, softcap, this_result, this_result + (Dv+0)*rk2, this_result + (Dv+1)*rk2)) return false;
}
@@ -281,6 +237,16 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
const float * Sj = Mj + rk2;
accumulate_qkv(Dv, M, S, Mj[il], Sj[il], Racc, R);
}
if (sinks) {
float s = ((const float *)sinks)[iq2];
if (s > M) {
float m = expf(M - s);
for (int i = 0; i < Dv; ++i) Racc[i] *= m;
S = S*m + 1;
} else {
S += expf(s - M);
}
}
float norm = S > 0 ? 1/S : 1;
for (int i = 0; i < Dv; ++i) Racc[i] *= norm;
}
@@ -306,6 +272,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
int counter = 0;
for (int64_t iq3 = 0; iq3 < neq3; iq3++) {
for (int64_t iq2 = 0; iq2 < neq2; iq2++) {
auto sinksf = sinks ? (const float *)sinks + iq2 : nullptr;
if (counter++ % (nth/ntg) == ith/ntg) {
int iq1 = (ith%ntg)*neq1g;
int this_neq1 = std::min(neq1g, neq1-iq1);
@@ -314,7 +281,7 @@ extern "C" IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float
(const float *)((const char *)q + iq2*nbq2 + iq3*nbq3 + iq1*stride_q),
(const void *)((const char *)k + iq2/rk2*nbk2 + iq3/rk3*nbk3),
(const void *)((const char *)v + iq2/rv2*nbv2 + iq3/rv3*nbv3),
(const void *)((const char *)mask + iq1*stride_m),
(const void *)((const char *)mask + iq1*stride_m), sinksf, 1,
scale, softcap,
(float *)((char *)qkv + (iq3*ne2*ne1 + iq2 + iq1*ne1)*nb1), nullptr, nullptr)) return false;
}

View File

@@ -23,6 +23,8 @@ bool iqk_flash_attn_impl(int type_k, // type of k
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const float * sinksf, // attention sinks
int nsinks, // number of sinks
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))

View File

@@ -120,16 +120,21 @@ struct MulMat {
funcs[n_left-1](n, vx, bx, info, nrc_x);
}
}
inline void gelu(int n, const float * src, float * dst);
inline void relu(int n, const float * src, float * dst);
inline void silu(int n, const float * src, float * dst);
inline void activate(ggml_unary_op op, int n, const float * src, float * dst) {
inline static void gelu(int n, const float * src, float * dst);
inline static void relu(int n, const float * src, float * dst);
inline static void silu(int n, const float * src, float * dst);
inline static void swiglu_oai(int n, const float * src, float * dst);
inline static void clamp_oai(int n, float *x);
inline static void activate(ggml_unary_op op, int n, const float * src, float * dst) {
if (op == GGML_UNARY_OP_GELU) gelu(n, src, dst);
else if (op == GGML_UNARY_OP_RELU) relu(n, src, dst);
else if (op == GGML_UNARY_OP_SILU) silu(n, src, dst);
else if (op == GGML_UNARY_OP_SWIGLU_OAI) swiglu_oai(n, src, dst);
else GGML_ABORT("fatal error");
}
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx, DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
inline void mul_mat_up_gate_NxM(int n, const void * vx_up, const void * vx_gate, size_t bx,
const float * up_b, const float * gate_b,
DataInfo& info, int nrc_x, int nrc_y, int unary_op) {
#ifdef __aarch64__
constexpr int k_x_step = 64; //8192; // Tiling does not seem to help on my M2 Max (but difference to tiling is small)
#else
@@ -137,6 +142,29 @@ struct MulMat {
#endif
auto op = ggml_unary_op(unary_op);
float tmp[k_x_step*16];
auto process = [&tmp, n, op, vx_gate, vx_up, gate_b, up_b, bx, xstep = k_x_step] (mul_mat_t func, const DataInfo& this_info, int ix, int this_nrc_x, int ny) {
func(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
if (gate_b) {
auto b = gate_b + ix;
auto x = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) x[j] += b[j];
}
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*xstep);
}
func(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
auto result = this_info.dst_row(ky);
if (up_b) {
auto b = up_b + ix;
for (int j = 0; j < this_nrc_x; ++j) result[j] += b[j];
}
if (op == GGML_UNARY_OP_SWIGLU_OAI) {
clamp_oai(this_nrc_x, result);
}
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*xstep + j];
}
};
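
The process lambda factors the per-tile epilogue out of the four dispatch branches below: gate matmul, optional gate bias, activation into tmp, then up matmul, optional up bias, the OAI clamp, and an elementwise product. A scalar sketch of the fused SWIGLU_OAI row computation, using the k_swiglu_oai constants defined further down (standalone helper, names hypothetical):

#include <algorithm>
#include <cmath>

// dst[j] = swiglu_oai(gate[j] + gate_b[j]) * (1 + clamp(up[j] + up_b[j], -limit, limit))
void swiglu_oai_row(int n, const float * gate, const float * up,
                    const float * gate_b, const float * up_b, float * dst) {
    constexpr float alpha = 1.702f; // k_swiglu_oai_alpha
    constexpr float limit = 7.0f;   // k_swiglu_oai_limit
    for (int j = 0; j < n; ++j) {
        float g = gate[j] + (gate_b ? gate_b[j] : 0.0f);
        float u = up[j]   + (up_b   ? up_b[j]   : 0.0f);
        g = std::min(g, limit);                  // MulMat::swiglu_oai clamps the gate ...
        g = g / (1.0f + std::exp(-alpha*g));     // ... then applies silu with alpha
        u = 1.0f + std::clamp(u, -limit, limit); // MulMat::clamp_oai
        dst[j] = g * u;
    }
}
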
if (func16 && nrc_y >= 16) {
int n_step = (nrc_y - info.cur_y)/16;
for (int ix = 0; ix < nrc_x; ix += k_x_step) {
@@ -144,15 +172,7 @@ struct MulMat {
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < n_step; ++iy) {
func16(n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < 16; ++ky) {
activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
}
func16(n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < 16; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(func16, this_info, ix, this_nrc_x, 16);
this_info.cur_y += 16;
}
}
@@ -175,23 +195,11 @@ struct MulMat {
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < my1; ++iy) {
funcs[ny1-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny1; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny1-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny1; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[ny1-1], this_info, ix, this_nrc_x, ny1);
this_info.cur_y += ny1;
}
for (int iy = 0; iy < my2; ++iy) {
funcs[ny2-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny2; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny2-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny2; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[ny2-1], this_info, ix, this_nrc_x, ny2);
this_info.cur_y += ny2;
}
}
@@ -203,13 +211,7 @@ struct MulMat {
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
for (int iy = 0; iy < n_step; ++iy) {
funcs[ny-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[ny-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < ny; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[ny-1], this_info, ix, this_nrc_x, ny);
this_info.cur_y += ny;
}
}
@@ -222,13 +224,7 @@ struct MulMat {
auto this_info = info;
this_info.s += ix;
int this_nrc_x = ix + k_x_step <= nrc_x ? k_x_step : nrc_x - ix;
funcs[n_left-1](n, (const void *)((const char *)vx_gate + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < n_left; ++ky) activate(op, this_nrc_x, this_info.dst_row(ky), tmp + ky*k_x_step);
funcs[n_left-1](n, (const void *)((const char *)vx_up + ix*bx), bx, this_info, this_nrc_x);
for (int ky = 0; ky < n_left; ++ky) {
auto result = this_info.dst_row(ky);
for (int j = 0; j < this_nrc_x; ++j) result[j] *= tmp[ky*k_x_step + j];
}
process(funcs[n_left-1], this_info, ix, this_nrc_x, n_left);
}
}
}
@@ -731,6 +727,7 @@ extern "C" IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB,
const char * up_b_c, const char * gate_b_c,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth) {
const mmid_row_mapping * row_mapping = (const mmid_row_mapping *)vrow_mapping;
@@ -774,7 +771,9 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
if (!iqk_convert_repack(typeA, ne00, (const char *)Agate + (first_x + ix)*strideA, strideA, Xg, ne00, this_nrc_x)) {
GGML_ABORT("Fatal error");
}
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, this_info, this_nrc_x, Ny, unary_op);
auto up_b = up_b_c ? (const float *)up_b_c + first_x + ix : nullptr;
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x + ix : nullptr;
mm.mul_mat_up_gate_NxM(ne00, Xu, Xg, row_size_qx, up_b, gate_b, this_info, this_nrc_x, Ny, unary_op);
}
return true;
@@ -795,7 +794,10 @@ extern "C" IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int n
nrc_x *= num_rows;
DataInfo info{C + first_x, (const char *)B, nb1/sizeof(float),
row_size_qy, 0, ne11, row_mapping, nb2/sizeof(float)};
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx, info, nrc_x, Ny, unary_op);
auto up_b = up_b_c ? (const float *)up_b_c + first_x : nullptr;
auto gate_b = gate_b_c ? (const float *)gate_b_c + first_x : nullptr;
mm.mul_mat_up_gate_NxM(ne00, (const char *)Aup + row_size_qx*first_x, (const char *)Agate + row_size_qx*first_x, row_size_qx,
up_b, gate_b, info, nrc_x, Ny, unary_op);
return true;
}
@@ -993,6 +995,46 @@ bool MulMat::prepare(int typeA, int typeB, int ne00, MulMat& m, int /*Ny*/) {
namespace {
// TODO: these swiglu_oai constants shouldn't be hard coded
constexpr float k_swiglu_oai_alpha = 1.702f;
constexpr float k_swiglu_oai_limit = 7.f;
void MulMat::swiglu_oai(int n, const float * x, float * y) {
// int i = 0;
//#if defined __AVX512F__ && defined __AVX512DQ__
// {
// auto max = _mm512_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm512_set1_ps(-k_swiglu_oai_alpha);
// for (; i + 15 < n; i += 16) {
// auto xc = v_clamp_max(_mm512_loadu_ps(x + i), max);
// _mm512_storeu_ps(y + i, v_silu_oai(xc, alpha));
// }
// }
//#endif
//#if defined __AVX2__ && defined __FMA__
// if (i + 7 < n) {
// auto max = _mm256_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm256_set1_ps(-k_swiglu_oai_alpha);
// for (; i + 7 < n; i += 8) {
// auto xc = v_clamp_max(_mm256_loadu_ps(x + i), max);
// _mm256_storeu_ps(y + i, v_silu_oai(xc, alpha));
// }
// }
//#endif
// for (; i < n; ++i) {
// auto xi = std::min(x[i], k_swiglu_oai_limit);
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
// }
for (int i = 0; i < n; ++i) {
auto xi = std::min(x[i], k_swiglu_oai_limit);
y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
}
}
void MulMat::clamp_oai(int n, float * x) {
for (int i = 0; i < n; ++i) x[i] = 1.f + std::max(std::min(x[i], k_swiglu_oai_limit), -k_swiglu_oai_limit);
}
#if defined(__ARM_NEON) && defined(__aarch64__)
void MulMat::gelu(int n, const float * x, float * y) {
constexpr float GELU_COEF_A = 0.044715f;
@@ -1040,6 +1082,37 @@ void MulMat::gelu(int n, const float * x, float * y) {
for (; i < n; ++i) y[i] = 0.5f*x[i]*(1.0f + tanhf(SQRT_2_OVER_PI*x[i]*(1.0f + GELU_COEF_A*x[i]*x[i])));
}
//void MulMat::swiglu_oai(int n, const float * x, float * y) {
// int i = 0;
//#if defined __AVX512F__ && defined __AVX512DQ__
// {
// auto limit = _mm512_set1_ps(k_swiglu_oai_limit);
// auto alpha = _mm512_set1_ps(k_swiglu_oai_alpha);
// for (; i + 15 < n; i += 16) {
// auto xi = _mm512_loadu_ps(x + i);
// auto mask = _mm512_cmp
//
// }
// __m512 c1 = _mm512_set1_ps(GELU_COEF_A);
// __m512 c2 = _mm512_set1_ps(2.f*SQRT_2_OVER_PI);
// for (; i + 15 < n; i += 16) _mm512_storeu_ps(y + i, v_gelu(_mm512_loadu_ps(x + i), c1, c2));
// }
//#endif
//#if defined __AVX2__ && defined __FMA__
// if (i + 7 < n) {
// __m256 c1 = _mm256_set1_ps(GELU_COEF_A);
// __m256 c2 = _mm256_set1_ps(2.f*SQRT_2_OVER_PI);
// for (; i + 7 < n; i += 8) _mm256_storeu_ps(y + i, v_gelu(_mm256_loadu_ps(x + i), c1, c2));
//
// }
//#endif
// for (; i < n; ++i) {
// auto xi = std::min(x[i], k_swiglu_oai_limit);
// y[i] = xi / (1.0f + expf(-xi * k_swiglu_oai_alpha));
// }
//}
void MulMat::silu(int n, const float * x, float * y) {
int i = 0;
#if defined __AVX512F__ && defined __AVX512DQ__
@@ -1188,6 +1261,8 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const float * sinksf, // attention sinks. If not null, one fp32 value per head
[[maybe_unused]] int nsinks, // number of sinks
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))
@@ -1197,32 +1272,32 @@ bool iqk_flash_attn_impl(int int_type_k, // type of k
if (Dk == 576 && Dv == 512) {
return iqk_fa_576_512(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S);
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 192 && Dv == 128) {
return iqk_fa_192_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S);
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 256 && Dv == 256) {
return iqk_fa_256_256(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S);
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 128 && Dv == 128) {
return iqk_fa_128_128(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S);
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 96 && Dv == 96) {
return iqk_fa_96_96(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S);
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
if (Dk == 64 && Dv == 64) {
return iqk_fa_64_64(int_type_k, int_type_v, nq1, nk1, stride_q, stride_k, stride_v, stride_m, stride_qkv,
q, k, v, mask, scale, softcap, qkv, M, S);
q, k, v, mask, scale, softcap, qkv, sinksf, M, S);
}
return false;

View File

@@ -32,6 +32,7 @@ IQK_API bool iqk_mul_mat_moe(long Nx, long Ny, long ne00, int ne11,
IQK_API bool iqk_moe_fused_up_gate(long Nx, long Ny, long ne00, int ne11, int unary_op,
int typeA, const void * Aup, const void * Agate, long strideA,
int typeB, const void * B, long strideB,
const char * up_b, const char * gate_b,
float * C, long nb1, long nb2, const void * vrow_mapping, int ith, int nth);
IQK_API int iqk_dequant_type(int type, int Ny);
@@ -57,6 +58,7 @@ IQK_API bool iqk_flash_attn_noalibi(int type_q, int type_mask, float max_bias,
const void * k, // k matrix. Assumed to be fp16, nq x nk elements
const void * v, // v matrix. Assumed to be fp16, nq x nk elements
const void * mask, // mask. If not null, assumed to be fp16. nq x nk elements
const void * sinks, // attention sinks. If not null, assumed to be fp32, one value per head
float scale, // scale applied before softmax
float softcap, // if > 0, a "soft-cap" operation is applied before softmax
float * qkv, // v*softmax(scale*(k*q))

View File

@@ -61,6 +61,13 @@ static inline float32x4_t v_silu(float32x4_t x) {
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vdivq_f32(x, one_plus_exp_neg_x);
}
static inline float32x4_t v_silu_oai(float32x4_t x, float32x4_t alpha) {
const float32x4_t one = vdupq_n_f32(1.0f);
const float32x4_t neg_x = vmulq_f32(alpha, x);
const float32x4_t exp_neg_x = v_expf(neg_x);
const float32x4_t one_plus_exp_neg_x = vaddq_f32(one, exp_neg_x);
return vdivq_f32(x, one_plus_exp_neg_x);
}
static inline float32x4_t v_gelu(float32x4_t x, float32x4_t c1, float32x4_t c2) {
const float32x4_t one = vdupq_n_f32(1.0f);
float32x4_t arg = vfmaq_f32(one, c1, vmulq_f32(x, x));
@@ -131,6 +138,17 @@ static inline __m512 v_silu(__m512 x) {
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
static inline __m512 v_silu_oai(__m512 x, __m512 alpha) {
const __m512 one = _mm512_set1_ps(1);
const __m512 neg_x = _mm512_mul_ps(alpha, x);
const __m512 exp_neg_x = v_expf(neg_x);
const __m512 one_plus_exp_neg_x = _mm512_add_ps(one, exp_neg_x);
return _mm512_div_ps(x, one_plus_exp_neg_x);
}
static inline __m512 v_clamp_max(__m512 x, __m512 max) {
auto mask = _mm512_cmp_ps_mask(x, max, _CMP_GT_OQ);
return _mm512_mask_blend_ps(mask, x, max);
}
#endif // __AVX512__
#if defined(__AVX2__) && defined(__FMA__)
@@ -195,12 +213,23 @@ static inline __m256 v_gelu(__m256 x, __m256 c1, __m256 c2) {
}
static inline __m256 v_silu(__m256 x) {
const __m256 one = _mm256_set1_ps(1);
const __m256 zero = _mm256_setzero_ps();
const __m256 neg_x = _mm256_sub_ps(zero, x);
const __m256 exp_neg_x = v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_div_ps(x, one_plus_exp_neg_x);
}
static inline __m256 v_silu_oai(__m256 x, __m256 alpha) {
const __m256 one = _mm256_set1_ps(1);
const __m256 neg_x = _mm256_mul_ps(alpha, x);
const __m256 exp_neg_x = v_expf(neg_x);
const __m256 one_plus_exp_neg_x = _mm256_add_ps(one, exp_neg_x);
return _mm256_div_ps(x, one_plus_exp_neg_x);
}
static inline __m256 v_clamp_max(__m256 x, __m256 max) {
auto mask = _mm256_cmp_ps(x, max, _CMP_GT_OQ);
return _mm256_or_ps(_mm256_and_ps(mask, max), _mm256_andnot_ps(mask, x));
}
#endif // __AVX2__
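
With v_clamp_max and v_silu_oai in place, the vectorized kernel that is currently commented out in MulMat::swiglu_oai reduces to a clamp-then-silu loop. An AVX2 sketch assuming these helpers (and v_expf) are in scope from this header; note that v_silu_oai expects the negated alpha, as in the commented-out AVX-512 code above:

#include <immintrin.h>
#include <algorithm>
#include <cmath>

void swiglu_oai_avx2(int n, const float * x, float * y) {
    constexpr float alpha = 1.702f, limit = 7.0f;
    int i = 0;
    const __m256 vmax   = _mm256_set1_ps(limit);
    const __m256 valpha = _mm256_set1_ps(-alpha);  // v_silu_oai computes exp(alpha*x) internally
    for (; i + 7 < n; i += 8) {
        __m256 xc = v_clamp_max(_mm256_loadu_ps(x + i), vmax); // min(x, limit)
        _mm256_storeu_ps(y + i, v_silu_oai(xc, valpha));       // x / (1 + exp(-1.702*x))
    }
    for (; i < n; ++i) {                                       // scalar tail
        float xi = std::min(x[i], limit);
        y[i] = xi / (1.0f + std::exp(-alpha*xi));
    }
}
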